18c08f21bSMogball //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
28c08f21bSMogball //
38c08f21bSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
48c08f21bSMogball // See https://llvm.org/LICENSE.txt for license information.
58c08f21bSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
68c08f21bSMogball //
78c08f21bSMogball //===----------------------------------------------------------------------===//
88c08f21bSMogball 
91fc096afSMehdi Amini #include <utility>
101fc096afSMehdi Amini 
118c08f21bSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
128c08f21bSMogball #include "mlir/Dialect/CommonFolders.h"
138c08f21bSMogball #include "mlir/IR/Builders.h"
148c08f21bSMogball #include "mlir/IR/Matchers.h"
158c08f21bSMogball #include "mlir/IR/OpImplementation.h"
168c08f21bSMogball #include "mlir/IR/PatternMatch.h"
178c08f21bSMogball #include "mlir/IR/TypeUtilities.h"
1806057248SRiver Riddle #include "llvm/ADT/SmallString.h"
198c08f21bSMogball 
20ca8997ebSWilliam S. Moses #include "llvm/ADT/APSInt.h"
21ca8997ebSWilliam S. Moses 
228c08f21bSMogball using namespace mlir;
238c08f21bSMogball using namespace mlir::arith;
248c08f21bSMogball 
258c08f21bSMogball //===----------------------------------------------------------------------===//
268c08f21bSMogball // Pattern helpers
278c08f21bSMogball //===----------------------------------------------------------------------===//
288c08f21bSMogball 
addIntegerAttrs(PatternRewriter & builder,Value res,Attribute lhs,Attribute rhs)298c08f21bSMogball static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
308c08f21bSMogball                                    Attribute lhs, Attribute rhs) {
318c08f21bSMogball   return builder.getIntegerAttr(res.getType(),
328c08f21bSMogball                                 lhs.cast<IntegerAttr>().getInt() +
338c08f21bSMogball                                     rhs.cast<IntegerAttr>().getInt());
348c08f21bSMogball }
358c08f21bSMogball 
subIntegerAttrs(PatternRewriter & builder,Value res,Attribute lhs,Attribute rhs)368c08f21bSMogball static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
378c08f21bSMogball                                    Attribute lhs, Attribute rhs) {
388c08f21bSMogball   return builder.getIntegerAttr(res.getType(),
398c08f21bSMogball                                 lhs.cast<IntegerAttr>().getInt() -
408c08f21bSMogball                                     rhs.cast<IntegerAttr>().getInt());
418c08f21bSMogball }
428c08f21bSMogball 
438c08f21bSMogball /// Invert an integer comparison predicate.
invertPredicate(arith::CmpIPredicate pred)4497567bdeSWilliam S. Moses arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
458c08f21bSMogball   switch (pred) {
468c08f21bSMogball   case arith::CmpIPredicate::eq:
478c08f21bSMogball     return arith::CmpIPredicate::ne;
488c08f21bSMogball   case arith::CmpIPredicate::ne:
498c08f21bSMogball     return arith::CmpIPredicate::eq;
508c08f21bSMogball   case arith::CmpIPredicate::slt:
518c08f21bSMogball     return arith::CmpIPredicate::sge;
528c08f21bSMogball   case arith::CmpIPredicate::sle:
538c08f21bSMogball     return arith::CmpIPredicate::sgt;
548c08f21bSMogball   case arith::CmpIPredicate::sgt:
558c08f21bSMogball     return arith::CmpIPredicate::sle;
568c08f21bSMogball   case arith::CmpIPredicate::sge:
578c08f21bSMogball     return arith::CmpIPredicate::slt;
588c08f21bSMogball   case arith::CmpIPredicate::ult:
598c08f21bSMogball     return arith::CmpIPredicate::uge;
608c08f21bSMogball   case arith::CmpIPredicate::ule:
618c08f21bSMogball     return arith::CmpIPredicate::ugt;
628c08f21bSMogball   case arith::CmpIPredicate::ugt:
638c08f21bSMogball     return arith::CmpIPredicate::ule;
648c08f21bSMogball   case arith::CmpIPredicate::uge:
658c08f21bSMogball     return arith::CmpIPredicate::ult;
668c08f21bSMogball   }
678c08f21bSMogball   llvm_unreachable("unknown cmpi predicate kind");
688c08f21bSMogball }
698c08f21bSMogball 
invertPredicate(arith::CmpIPredicateAttr pred)708c08f21bSMogball static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
718c08f21bSMogball   return arith::CmpIPredicateAttr::get(pred.getContext(),
728c08f21bSMogball                                        invertPredicate(pred.getValue()));
738c08f21bSMogball }
748c08f21bSMogball 
758c08f21bSMogball //===----------------------------------------------------------------------===//
768c08f21bSMogball // TableGen'd canonicalization patterns
778c08f21bSMogball //===----------------------------------------------------------------------===//
788c08f21bSMogball 
798c08f21bSMogball namespace {
808c08f21bSMogball #include "ArithmeticCanonicalization.inc"
81be0a7e9fSMehdi Amini } // namespace
828c08f21bSMogball 
838c08f21bSMogball //===----------------------------------------------------------------------===//
84a54f4eaeSMogball // ConstantOp
85a54f4eaeSMogball //===----------------------------------------------------------------------===//
86a54f4eaeSMogball 
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)87a54f4eaeSMogball void arith::ConstantOp::getAsmResultNames(
88a54f4eaeSMogball     function_ref<void(Value, StringRef)> setNameFn) {
89a54f4eaeSMogball   auto type = getType();
90cfb72fd3SJacques Pienaar   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91a54f4eaeSMogball     auto intType = type.dyn_cast<IntegerType>();
92a54f4eaeSMogball 
93a54f4eaeSMogball     // Sugar i1 constants with 'true' and 'false'.
94a54f4eaeSMogball     if (intType && intType.getWidth() == 1)
95a54f4eaeSMogball       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96a54f4eaeSMogball 
97*1ef32e78SMarius Hillenbrand     // Otherwise, build a complex name with the value and type.
98a54f4eaeSMogball     SmallString<32> specialNameBuffer;
99a54f4eaeSMogball     llvm::raw_svector_ostream specialName(specialNameBuffer);
100*1ef32e78SMarius Hillenbrand     specialName << 'c' << intCst.getValue();
101a54f4eaeSMogball     if (intType)
102a54f4eaeSMogball       specialName << '_' << type;
103a54f4eaeSMogball     setNameFn(getResult(), specialName.str());
104a54f4eaeSMogball   } else {
105a54f4eaeSMogball     setNameFn(getResult(), "cst");
106a54f4eaeSMogball   }
107a54f4eaeSMogball }
108a54f4eaeSMogball 
109a54f4eaeSMogball /// TODO: disallow arith.constant to return anything other than signless integer
110a54f4eaeSMogball /// or float like.
verify()1111be88f5aSRiver Riddle LogicalResult arith::ConstantOp::verify() {
1121be88f5aSRiver Riddle   auto type = getType();
113a54f4eaeSMogball   // The value's type must match the return type.
1141be88f5aSRiver Riddle   if (getValue().getType() != type) {
1151be88f5aSRiver Riddle     return emitOpError() << "value type " << getValue().getType()
116a54f4eaeSMogball                          << " must match return type: " << type;
117a54f4eaeSMogball   }
118a54f4eaeSMogball   // Integer values must be signless.
119a54f4eaeSMogball   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
1201be88f5aSRiver Riddle     return emitOpError("integer return type must be signless");
121a54f4eaeSMogball   // Any float or elements attribute are acceptable.
1221be88f5aSRiver Riddle   if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
1231be88f5aSRiver Riddle     return emitOpError(
124a54f4eaeSMogball         "value must be an integer, float, or elements attribute");
125a54f4eaeSMogball   }
126a54f4eaeSMogball   return success();
127a54f4eaeSMogball }
128a54f4eaeSMogball 
isBuildableWith(Attribute value,Type type)129a54f4eaeSMogball bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130a54f4eaeSMogball   // The value's type must be the same as the provided type.
131a54f4eaeSMogball   if (value.getType() != type)
132a54f4eaeSMogball     return false;
133a54f4eaeSMogball   // Integer values must be signless.
134a54f4eaeSMogball   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135a54f4eaeSMogball     return false;
136a54f4eaeSMogball   // Integer, float, and element attributes are buildable.
137a54f4eaeSMogball   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138a54f4eaeSMogball }
139a54f4eaeSMogball 
fold(ArrayRef<Attribute> operands)140a54f4eaeSMogball OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141cfb72fd3SJacques Pienaar   return getValue();
142a54f4eaeSMogball }
143a54f4eaeSMogball 
build(OpBuilder & builder,OperationState & result,int64_t value,unsigned width)144a54f4eaeSMogball void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145a54f4eaeSMogball                                  int64_t value, unsigned width) {
146a54f4eaeSMogball   auto type = builder.getIntegerType(width);
147a54f4eaeSMogball   arith::ConstantOp::build(builder, result, type,
148a54f4eaeSMogball                            builder.getIntegerAttr(type, value));
149a54f4eaeSMogball }
150a54f4eaeSMogball 
build(OpBuilder & builder,OperationState & result,int64_t value,Type type)151a54f4eaeSMogball void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152a54f4eaeSMogball                                  int64_t value, Type type) {
153a54f4eaeSMogball   assert(type.isSignlessInteger() &&
154a54f4eaeSMogball          "ConstantIntOp can only have signless integer type values");
155a54f4eaeSMogball   arith::ConstantOp::build(builder, result, type,
156a54f4eaeSMogball                            builder.getIntegerAttr(type, value));
157a54f4eaeSMogball }
158a54f4eaeSMogball 
classof(Operation * op)159a54f4eaeSMogball bool arith::ConstantIntOp::classof(Operation *op) {
160a54f4eaeSMogball   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161a54f4eaeSMogball     return constOp.getType().isSignlessInteger();
162a54f4eaeSMogball   return false;
163a54f4eaeSMogball }
164a54f4eaeSMogball 
build(OpBuilder & builder,OperationState & result,const APFloat & value,FloatType type)165a54f4eaeSMogball void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166a54f4eaeSMogball                                    const APFloat &value, FloatType type) {
167a54f4eaeSMogball   arith::ConstantOp::build(builder, result, type,
168a54f4eaeSMogball                            builder.getFloatAttr(type, value));
169a54f4eaeSMogball }
170a54f4eaeSMogball 
classof(Operation * op)171a54f4eaeSMogball bool arith::ConstantFloatOp::classof(Operation *op) {
172a54f4eaeSMogball   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173a54f4eaeSMogball     return constOp.getType().isa<FloatType>();
174a54f4eaeSMogball   return false;
175a54f4eaeSMogball }
176a54f4eaeSMogball 
build(OpBuilder & builder,OperationState & result,int64_t value)177a54f4eaeSMogball void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178a54f4eaeSMogball                                    int64_t value) {
179a54f4eaeSMogball   arith::ConstantOp::build(builder, result, builder.getIndexType(),
180a54f4eaeSMogball                            builder.getIndexAttr(value));
181a54f4eaeSMogball }
182a54f4eaeSMogball 
classof(Operation * op)183a54f4eaeSMogball bool arith::ConstantIndexOp::classof(Operation *op) {
184a54f4eaeSMogball   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185a54f4eaeSMogball     return constOp.getType().isIndex();
186a54f4eaeSMogball   return false;
187a54f4eaeSMogball }
188a54f4eaeSMogball 
189a54f4eaeSMogball //===----------------------------------------------------------------------===//
1908c08f21bSMogball // AddIOp
1918c08f21bSMogball //===----------------------------------------------------------------------===//
1928c08f21bSMogball 
fold(ArrayRef<Attribute> operands)1938c08f21bSMogball OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
1948c08f21bSMogball   // addi(x, 0) -> x
195cfb72fd3SJacques Pienaar   if (matchPattern(getRhs(), m_Zero()))
196cfb72fd3SJacques Pienaar     return getLhs();
1978c08f21bSMogball 
198f278cf9cSChristian Sigg   // addi(subi(a, b), b) -> a
19921aa2a1bSWilliam S. Moses   if (auto sub = getLhs().getDefiningOp<SubIOp>())
20021aa2a1bSWilliam S. Moses     if (getRhs() == sub.getRhs())
20121aa2a1bSWilliam S. Moses       return sub.getLhs();
20221aa2a1bSWilliam S. Moses 
203f278cf9cSChristian Sigg   // addi(b, subi(a, b)) -> a
20421aa2a1bSWilliam S. Moses   if (auto sub = getRhs().getDefiningOp<SubIOp>())
20521aa2a1bSWilliam S. Moses     if (getLhs() == sub.getRhs())
20621aa2a1bSWilliam S. Moses       return sub.getLhs();
20721aa2a1bSWilliam S. Moses 
2081fc096afSMehdi Amini   return constFoldBinaryOp<IntegerAttr>(
2091fc096afSMehdi Amini       operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
2108c08f21bSMogball }
2118c08f21bSMogball 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)212b7f93c28SJeff Niu void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
213b7f93c28SJeff Niu                                                 MLIRContext *context) {
214b4e0507cSTres Popp   patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
2158c08f21bSMogball       context);
2168c08f21bSMogball }
2178c08f21bSMogball 
2188c08f21bSMogball //===----------------------------------------------------------------------===//
2198c08f21bSMogball // SubIOp
2208c08f21bSMogball //===----------------------------------------------------------------------===//
2218c08f21bSMogball 
fold(ArrayRef<Attribute> operands)2228c08f21bSMogball OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
2238c08f21bSMogball   // subi(x,x) -> 0
2248c08f21bSMogball   if (getOperand(0) == getOperand(1))
2258c08f21bSMogball     return Builder(getContext()).getZeroAttr(getType());
2268c08f21bSMogball   // subi(x,0) -> x
227cfb72fd3SJacques Pienaar   if (matchPattern(getRhs(), m_Zero()))
228cfb72fd3SJacques Pienaar     return getLhs();
2298c08f21bSMogball 
2301fc096afSMehdi Amini   return constFoldBinaryOp<IntegerAttr>(
2311fc096afSMehdi Amini       operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
2328c08f21bSMogball }
2338c08f21bSMogball 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)234b7f93c28SJeff Niu void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
235b7f93c28SJeff Niu                                                 MLIRContext *context) {
236b4e0507cSTres Popp   patterns
237b4e0507cSTres Popp       .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238b4e0507cSTres Popp            SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239b4e0507cSTres Popp           context);
2408c08f21bSMogball }
2418c08f21bSMogball 
2428c08f21bSMogball //===----------------------------------------------------------------------===//
2438c08f21bSMogball // MulIOp
2448c08f21bSMogball //===----------------------------------------------------------------------===//
2458c08f21bSMogball 
fold(ArrayRef<Attribute> operands)2468c08f21bSMogball OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
2478c08f21bSMogball   // muli(x, 0) -> 0
248cfb72fd3SJacques Pienaar   if (matchPattern(getRhs(), m_Zero()))
249cfb72fd3SJacques Pienaar     return getRhs();
2508c08f21bSMogball   // muli(x, 1) -> x
251cfb72fd3SJacques Pienaar   if (matchPattern(getRhs(), m_One()))
2528c08f21bSMogball     return getOperand(0);
2538c08f21bSMogball   // TODO: Handle the overflow case.
2548c08f21bSMogball 
2558c08f21bSMogball   // default folder
2561fc096afSMehdi Amini   return constFoldBinaryOp<IntegerAttr>(
2571fc096afSMehdi Amini       operands, [](const APInt &a, const APInt &b) { return a * b; });
2588c08f21bSMogball }
2598c08f21bSMogball 
2608c08f21bSMogball //===----------------------------------------------------------------------===//
2618c08f21bSMogball // DivUIOp
2628c08f21bSMogball //===----------------------------------------------------------------------===//
2638c08f21bSMogball 
fold(ArrayRef<Attribute> operands)2648c08f21bSMogball OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265abc17a67Sjacquesguan   // divui (x, 1) -> x.
266abc17a67Sjacquesguan   if (matchPattern(getRhs(), m_One()))
267abc17a67Sjacquesguan     return getLhs();
268abc17a67Sjacquesguan 
2698c08f21bSMogball   // Don't fold if it would require a division by zero.
2708c08f21bSMogball   bool div0 = false;
2711fc096afSMehdi Amini   auto result =
2721fc096afSMehdi Amini       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
2738c08f21bSMogball         if (div0 || !b) {
2748c08f21bSMogball           div0 = true;
2758c08f21bSMogball           return a;
2768c08f21bSMogball         }
2778c08f21bSMogball         return a.udiv(b);
2788c08f21bSMogball       });
2798c08f21bSMogball 
2808c08f21bSMogball   return div0 ? Attribute() : result;
2818c08f21bSMogball }
2828c08f21bSMogball 
2838c08f21bSMogball //===----------------------------------------------------------------------===//
2848c08f21bSMogball // DivSIOp
2858c08f21bSMogball //===----------------------------------------------------------------------===//
2868c08f21bSMogball 
fold(ArrayRef<Attribute> operands)2878c08f21bSMogball OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
288abc17a67Sjacquesguan   // divsi (x, 1) -> x.
289abc17a67Sjacquesguan   if (matchPattern(getRhs(), m_One()))
290abc17a67Sjacquesguan     return getLhs();
291abc17a67Sjacquesguan 
2928c08f21bSMogball   // Don't fold if it would overflow or if it requires a division by zero.
2938c08f21bSMogball   bool overflowOrDiv0 = false;
2941fc096afSMehdi Amini   auto result =
2951fc096afSMehdi Amini       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
2968c08f21bSMogball         if (overflowOrDiv0 || !b) {
2978c08f21bSMogball           overflowOrDiv0 = true;
2988c08f21bSMogball           return a;
2998c08f21bSMogball         }
3008c08f21bSMogball         return a.sdiv_ov(b, overflowOrDiv0);
3018c08f21bSMogball       });
3028c08f21bSMogball 
3038c08f21bSMogball   return overflowOrDiv0 ? Attribute() : result;
3048c08f21bSMogball }
3058c08f21bSMogball 
3068c08f21bSMogball //===----------------------------------------------------------------------===//
3078c08f21bSMogball // Ceil and floor division folding helpers
3088c08f21bSMogball //===----------------------------------------------------------------------===//
3098c08f21bSMogball 
/// Returns ceil(a / b) for signed `a` >= 0 and `b` > 0.
///
/// For a > 0 it computes (a - 1) / b + 1; a == 0 returns 0 directly, because
/// the formula would yield 1 there whenever b > 1.
///
/// `overflow` is sticky: it is set to true if any intermediate signed operation
/// overflows, and it is never cleared. APInt's *_ov helpers assign their flag
/// rather than OR into it, so chaining them on a single flag would otherwise
/// let a later step hide an earlier overflow.
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
                                    bool &overflow) {
  if (!a)
    return a;
  APInt one(a.getBitWidth(), 1, /*isSigned=*/true);
  bool subOverflow = false;
  bool divOverflow = false;
  bool addOverflow = false;
  APInt val = a.ssub_ov(one, subOverflow).sdiv_ov(b, divOverflow);
  APInt res = val.sadd_ov(one, addOverflow);
  overflow = overflow || subOverflow || divOverflow || addOverflow;
  return res;
}
3178c08f21bSMogball 
3188c08f21bSMogball //===----------------------------------------------------------------------===//
3198165eaa8Slipracer // CeilDivUIOp
3208165eaa8Slipracer //===----------------------------------------------------------------------===//
3218165eaa8Slipracer 
fold(ArrayRef<Attribute> operands)3228165eaa8Slipracer OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
323abc17a67Sjacquesguan   // ceildivui (x, 1) -> x.
324abc17a67Sjacquesguan   if (matchPattern(getRhs(), m_One()))
325abc17a67Sjacquesguan     return getLhs();
326abc17a67Sjacquesguan 
3278165eaa8Slipracer   bool overflowOrDiv0 = false;
3281fc096afSMehdi Amini   auto result =
3291fc096afSMehdi Amini       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
3308165eaa8Slipracer         if (overflowOrDiv0 || !b) {
3318165eaa8Slipracer           overflowOrDiv0 = true;
3328165eaa8Slipracer           return a;
3338165eaa8Slipracer         }
3348165eaa8Slipracer         APInt quotient = a.udiv(b);
3358165eaa8Slipracer         if (!a.urem(b))
3368165eaa8Slipracer           return quotient;
3378165eaa8Slipracer         APInt one(a.getBitWidth(), 1, true);
3388165eaa8Slipracer         return quotient.uadd_ov(one, overflowOrDiv0);
3398165eaa8Slipracer       });
3408165eaa8Slipracer 
3418165eaa8Slipracer   return overflowOrDiv0 ? Attribute() : result;
3428165eaa8Slipracer }
3438165eaa8Slipracer 
3448165eaa8Slipracer //===----------------------------------------------------------------------===//
3458c08f21bSMogball // CeilDivSIOp
3468c08f21bSMogball //===----------------------------------------------------------------------===//
3478c08f21bSMogball 
fold(ArrayRef<Attribute> operands)3488c08f21bSMogball OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
349abc17a67Sjacquesguan   // ceildivsi (x, 1) -> x.
350abc17a67Sjacquesguan   if (matchPattern(getRhs(), m_One()))
351abc17a67Sjacquesguan     return getLhs();
352abc17a67Sjacquesguan 
3538c08f21bSMogball   // Don't fold if it would overflow or if it requires a division by zero.
3548c08f21bSMogball   bool overflowOrDiv0 = false;
3551fc096afSMehdi Amini   auto result =
3561fc096afSMehdi Amini       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
3578c08f21bSMogball         if (overflowOrDiv0 || !b) {
3588c08f21bSMogball           overflowOrDiv0 = true;
3598c08f21bSMogball           return a;
3608c08f21bSMogball         }
361a1e62aa7SMehdi Amini         if (!a)
362a1e62aa7SMehdi Amini           return a;
363a1e62aa7SMehdi Amini         // After this point we know that neither a or b are zero.
3648c08f21bSMogball         unsigned bits = a.getBitWidth();
3658c08f21bSMogball         APInt zero = APInt::getZero(bits);
366a1e62aa7SMehdi Amini         bool aGtZero = a.sgt(zero);
367a1e62aa7SMehdi Amini         bool bGtZero = b.sgt(zero);
368a1e62aa7SMehdi Amini         if (aGtZero && bGtZero) {
3698c08f21bSMogball           // Both positive, return ceil(a, b).
3708c08f21bSMogball           return signedCeilNonnegInputs(a, b, overflowOrDiv0);
3718c08f21bSMogball         }
372a1e62aa7SMehdi Amini         if (!aGtZero && !bGtZero) {
3738c08f21bSMogball           // Both negative, return ceil(-a, -b).
3748c08f21bSMogball           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
3758c08f21bSMogball           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
3768c08f21bSMogball           return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
3778c08f21bSMogball         }
378a1e62aa7SMehdi Amini         if (!aGtZero && bGtZero) {
3798c08f21bSMogball           // A is negative, b is positive, return - ( -a / b).
3808c08f21bSMogball           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
3818c08f21bSMogball           APInt div = posA.sdiv_ov(b, overflowOrDiv0);
3828c08f21bSMogball           return zero.ssub_ov(div, overflowOrDiv0);
3838c08f21bSMogball         }
384a1e62aa7SMehdi Amini         // A is positive, b is negative, return - (a / -b).
3858c08f21bSMogball         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
3868c08f21bSMogball         APInt div = a.sdiv_ov(posB, overflowOrDiv0);
3878c08f21bSMogball         return zero.ssub_ov(div, overflowOrDiv0);
3888c08f21bSMogball       });
3898c08f21bSMogball 
3908c08f21bSMogball   return overflowOrDiv0 ? Attribute() : result;
3918c08f21bSMogball }
3928c08f21bSMogball 
3938c08f21bSMogball //===----------------------------------------------------------------------===//
3948c08f21bSMogball // FloorDivSIOp
3958c08f21bSMogball //===----------------------------------------------------------------------===//
3968c08f21bSMogball 
fold(ArrayRef<Attribute> operands)3978c08f21bSMogball OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
398abc17a67Sjacquesguan   // floordivsi (x, 1) -> x.
399abc17a67Sjacquesguan   if (matchPattern(getRhs(), m_One()))
400abc17a67Sjacquesguan     return getLhs();
401abc17a67Sjacquesguan 
4028c08f21bSMogball   // Don't fold if it would overflow or if it requires a division by zero.
4038c08f21bSMogball   bool overflowOrDiv0 = false;
4041fc096afSMehdi Amini   auto result =
4051fc096afSMehdi Amini       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
4068c08f21bSMogball         if (overflowOrDiv0 || !b) {
4078c08f21bSMogball           overflowOrDiv0 = true;
4088c08f21bSMogball           return a;
4098c08f21bSMogball         }
410a1e62aa7SMehdi Amini         if (!a)
411a1e62aa7SMehdi Amini           return a;
412a1e62aa7SMehdi Amini         // After this point we know that neither a or b are zero.
4138c08f21bSMogball         unsigned bits = a.getBitWidth();
4148c08f21bSMogball         APInt zero = APInt::getZero(bits);
415a1e62aa7SMehdi Amini         bool aGtZero = a.sgt(zero);
416a1e62aa7SMehdi Amini         bool bGtZero = b.sgt(zero);
417a1e62aa7SMehdi Amini         if (aGtZero && bGtZero) {
418a1e62aa7SMehdi Amini           // Both positive, return a / b.
4198c08f21bSMogball           return a.sdiv_ov(b, overflowOrDiv0);
4208c08f21bSMogball         }
421a1e62aa7SMehdi Amini         if (!aGtZero && !bGtZero) {
422a1e62aa7SMehdi Amini           // Both negative, return -a / -b.
4238c08f21bSMogball           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
4248c08f21bSMogball           APInt posB = zero.ssub_ov(b, overflowOrDiv0);
4258c08f21bSMogball           return posA.sdiv_ov(posB, overflowOrDiv0);
4268c08f21bSMogball         }
427a1e62aa7SMehdi Amini         if (!aGtZero && bGtZero) {
4288c08f21bSMogball           // A is negative, b is positive, return - ceil(-a, b).
4298c08f21bSMogball           APInt posA = zero.ssub_ov(a, overflowOrDiv0);
4308c08f21bSMogball           APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
4318c08f21bSMogball           return zero.ssub_ov(ceil, overflowOrDiv0);
4328c08f21bSMogball         }
4338c08f21bSMogball         // A is positive, b is negative, return - ceil(a, -b).
4348c08f21bSMogball         APInt posB = zero.ssub_ov(b, overflowOrDiv0);
4358c08f21bSMogball         APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
4368c08f21bSMogball         return zero.ssub_ov(ceil, overflowOrDiv0);
4378c08f21bSMogball       });
4388c08f21bSMogball 
4398c08f21bSMogball   return overflowOrDiv0 ? Attribute() : result;
4408c08f21bSMogball }
4418c08f21bSMogball 
4428c08f21bSMogball //===----------------------------------------------------------------------===//
4438c08f21bSMogball // RemUIOp
4448c08f21bSMogball //===----------------------------------------------------------------------===//
4458c08f21bSMogball 
fold(ArrayRef<Attribute> operands)4468c08f21bSMogball OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
4479b32886eSjacquesguan   // remui (x, 1) -> 0.
4489b32886eSjacquesguan   if (matchPattern(getRhs(), m_One()))
4499b32886eSjacquesguan     return Builder(getContext()).getZeroAttr(getType());
4508c08f21bSMogball 
4519b32886eSjacquesguan   // Don't fold if it would require a division by zero.
4529b32886eSjacquesguan   bool div0 = false;
4539b32886eSjacquesguan   auto result =
4549b32886eSjacquesguan       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
4559b32886eSjacquesguan         if (div0 || b.isNullValue()) {
4569b32886eSjacquesguan           div0 = true;
4579b32886eSjacquesguan           return a;
4589b32886eSjacquesguan         }
4599b32886eSjacquesguan         return a.urem(b);
4609b32886eSjacquesguan       });
4618c08f21bSMogball 
4629b32886eSjacquesguan   return div0 ? Attribute() : result;
4638c08f21bSMogball }
4648c08f21bSMogball 
4658c08f21bSMogball //===----------------------------------------------------------------------===//
4668c08f21bSMogball // RemSIOp
4678c08f21bSMogball //===----------------------------------------------------------------------===//
4688c08f21bSMogball 
fold(ArrayRef<Attribute> operands)4698c08f21bSMogball OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
4709b32886eSjacquesguan   // remsi (x, 1) -> 0.
4719b32886eSjacquesguan   if (matchPattern(getRhs(), m_One()))
4729b32886eSjacquesguan     return Builder(getContext()).getZeroAttr(getType());
4738c08f21bSMogball 
4749b32886eSjacquesguan   // Don't fold if it would require a division by zero.
4759b32886eSjacquesguan   bool div0 = false;
4769b32886eSjacquesguan   auto result =
4779b32886eSjacquesguan       constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
4789b32886eSjacquesguan         if (div0 || b.isNullValue()) {
4799b32886eSjacquesguan           div0 = true;
4809b32886eSjacquesguan           return a;
4819b32886eSjacquesguan         }
4829b32886eSjacquesguan         return a.srem(b);
4839b32886eSjacquesguan       });
4848c08f21bSMogball 
4859b32886eSjacquesguan   return div0 ? Attribute() : result;
4868c08f21bSMogball }
4878c08f21bSMogball 
4888c08f21bSMogball //===----------------------------------------------------------------------===//
4898c08f21bSMogball // AndIOp
4908c08f21bSMogball //===----------------------------------------------------------------------===//
4918c08f21bSMogball 
fold(ArrayRef<Attribute> operands)4928c08f21bSMogball OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
4938c08f21bSMogball   /// and(x, 0) -> 0
494cfb72fd3SJacques Pienaar   if (matchPattern(getRhs(), m_Zero()))
495cfb72fd3SJacques Pienaar     return getRhs();
4968c08f21bSMogball   /// and(x, allOnes) -> x
4978c08f21bSMogball   APInt intValue;
498cfb72fd3SJacques Pienaar   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
499cfb72fd3SJacques Pienaar     return getLhs();
5008c08f21bSMogball 
5011fc096afSMehdi Amini   return constFoldBinaryOp<IntegerAttr>(
5021fc096afSMehdi Amini       operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
5038c08f21bSMogball }
5048c08f21bSMogball 
5058c08f21bSMogball //===----------------------------------------------------------------------===//
5068c08f21bSMogball // OrIOp
5078c08f21bSMogball //===----------------------------------------------------------------------===//
5088c08f21bSMogball 
fold(ArrayRef<Attribute> operands)5098c08f21bSMogball OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
5108c08f21bSMogball   /// or(x, 0) -> x
511cfb72fd3SJacques Pienaar   if (matchPattern(getRhs(), m_Zero()))
512cfb72fd3SJacques Pienaar     return getLhs();
513a54f4eaeSMogball   /// or(x, <all ones>) -> <all ones>
514a54f4eaeSMogball   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
515a54f4eaeSMogball     if (rhsAttr.getValue().isAllOnes())
516a54f4eaeSMogball       return rhsAttr;
5178c08f21bSMogball 
5181fc096afSMehdi Amini   return constFoldBinaryOp<IntegerAttr>(
5191fc096afSMehdi Amini       operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
5208c08f21bSMogball }
5218c08f21bSMogball 
5228c08f21bSMogball //===----------------------------------------------------------------------===//
5238c08f21bSMogball // XOrIOp
5248c08f21bSMogball //===----------------------------------------------------------------------===//
5258c08f21bSMogball 
fold(ArrayRef<Attribute> operands)5268c08f21bSMogball OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
5278c08f21bSMogball   /// xor(x, 0) -> x
528cfb72fd3SJacques Pienaar   if (matchPattern(getRhs(), m_Zero()))
529cfb72fd3SJacques Pienaar     return getLhs();
5308c08f21bSMogball   /// xor(x, x) -> 0
531cfb72fd3SJacques Pienaar   if (getLhs() == getRhs())
5328c08f21bSMogball     return Builder(getContext()).getZeroAttr(getType());
53334646a2fSWilliam S. Moses   /// xor(xor(x, a), a) -> x
53434646a2fSWilliam S. Moses   if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
53534646a2fSWilliam S. Moses     if (prev.getRhs() == getRhs())
53634646a2fSWilliam S. Moses       return prev.getLhs();
5378c08f21bSMogball 
5381fc096afSMehdi Amini   return constFoldBinaryOp<IntegerAttr>(
5391fc096afSMehdi Amini       operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
5408c08f21bSMogball }
5418c08f21bSMogball 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)542b7f93c28SJeff Niu void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
543b7f93c28SJeff Niu                                                 MLIRContext *context) {
544b4e0507cSTres Popp   patterns.add<XOrINotCmpI>(context);
5458c08f21bSMogball }
5468c08f21bSMogball 
5478c08f21bSMogball //===----------------------------------------------------------------------===//
548088d3889Sjacquesguan // NegFOp
549088d3889Sjacquesguan //===----------------------------------------------------------------------===//
550088d3889Sjacquesguan 
fold(ArrayRef<Attribute> operands)551088d3889Sjacquesguan OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
5520e02bf63Sjacquesguan   /// negf(negf(x)) -> x
5530e02bf63Sjacquesguan   if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
5540e02bf63Sjacquesguan     return op.getOperand();
555088d3889Sjacquesguan   return constFoldUnaryOp<FloatAttr>(operands,
556088d3889Sjacquesguan                                      [](const APFloat &a) { return -a; });
557088d3889Sjacquesguan }
558088d3889Sjacquesguan 
559088d3889Sjacquesguan //===----------------------------------------------------------------------===//
5608c08f21bSMogball // AddFOp
5618c08f21bSMogball //===----------------------------------------------------------------------===//
5628c08f21bSMogball 
fold(ArrayRef<Attribute> operands)5638c08f21bSMogball OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
564f278cf9cSChristian Sigg   // addf(x, -0) -> x
565f278cf9cSChristian Sigg   if (matchPattern(getRhs(), m_NegZeroFloat()))
566f278cf9cSChristian Sigg     return getLhs();
567f278cf9cSChristian Sigg 
5688c08f21bSMogball   return constFoldBinaryOp<FloatAttr>(
5691fc096afSMehdi Amini       operands, [](const APFloat &a, const APFloat &b) { return a + b; });
5708c08f21bSMogball }
5718c08f21bSMogball 
5728c08f21bSMogball //===----------------------------------------------------------------------===//
5738c08f21bSMogball // SubFOp
5748c08f21bSMogball //===----------------------------------------------------------------------===//
5758c08f21bSMogball 
fold(ArrayRef<Attribute> operands)5768c08f21bSMogball OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
577f278cf9cSChristian Sigg   // subf(x, +0) -> x
578f278cf9cSChristian Sigg   if (matchPattern(getRhs(), m_PosZeroFloat()))
579f278cf9cSChristian Sigg     return getLhs();
580f278cf9cSChristian Sigg 
5818c08f21bSMogball   return constFoldBinaryOp<FloatAttr>(
5821fc096afSMehdi Amini       operands, [](const APFloat &a, const APFloat &b) { return a - b; });
5838c08f21bSMogball }
5848c08f21bSMogball 
5858c08f21bSMogball //===----------------------------------------------------------------------===//
586f278cf9cSChristian Sigg // MaxFOp
587f278cf9cSChristian Sigg //===----------------------------------------------------------------------===//
588f278cf9cSChristian Sigg 
fold(ArrayRef<Attribute> operands)589f278cf9cSChristian Sigg OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
590f278cf9cSChristian Sigg   assert(operands.size() == 2 && "maxf takes two operands");
591f278cf9cSChristian Sigg 
592f278cf9cSChristian Sigg   // maxf(x,x) -> x
593f278cf9cSChristian Sigg   if (getLhs() == getRhs())
594f278cf9cSChristian Sigg     return getRhs();
595f278cf9cSChristian Sigg 
596f278cf9cSChristian Sigg   // maxf(x, -inf) -> x
597f278cf9cSChristian Sigg   if (matchPattern(getRhs(), m_NegInfFloat()))
598f278cf9cSChristian Sigg     return getLhs();
599f278cf9cSChristian Sigg 
600f278cf9cSChristian Sigg   return constFoldBinaryOp<FloatAttr>(
601f278cf9cSChristian Sigg       operands,
602f278cf9cSChristian Sigg       [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
603f278cf9cSChristian Sigg }
604f278cf9cSChristian Sigg 
605f278cf9cSChristian Sigg //===----------------------------------------------------------------------===//
6069b1d90e8SAlexander Belyaev // MaxSIOp
6079b1d90e8SAlexander Belyaev //===----------------------------------------------------------------------===//
6089b1d90e8SAlexander Belyaev 
fold(ArrayRef<Attribute> operands)6099b1d90e8SAlexander Belyaev OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
6109b1d90e8SAlexander Belyaev   assert(operands.size() == 2 && "binary operation takes two operands");
6119b1d90e8SAlexander Belyaev 
6129b1d90e8SAlexander Belyaev   // maxsi(x,x) -> x
6139b1d90e8SAlexander Belyaev   if (getLhs() == getRhs())
6149b1d90e8SAlexander Belyaev     return getRhs();
6159b1d90e8SAlexander Belyaev 
6169b1d90e8SAlexander Belyaev   APInt intValue;
6179b1d90e8SAlexander Belyaev   // maxsi(x,MAX_INT) -> MAX_INT
6189b1d90e8SAlexander Belyaev   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
6199b1d90e8SAlexander Belyaev       intValue.isMaxSignedValue())
6209b1d90e8SAlexander Belyaev     return getRhs();
6219b1d90e8SAlexander Belyaev 
6229b1d90e8SAlexander Belyaev   // maxsi(x, MIN_INT) -> x
6239b1d90e8SAlexander Belyaev   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
6249b1d90e8SAlexander Belyaev       intValue.isMinSignedValue())
6259b1d90e8SAlexander Belyaev     return getLhs();
6269b1d90e8SAlexander Belyaev 
6271fc096afSMehdi Amini   return constFoldBinaryOp<IntegerAttr>(operands,
6281fc096afSMehdi Amini                                         [](const APInt &a, const APInt &b) {
6291fc096afSMehdi Amini                                           return llvm::APIntOps::smax(a, b);
6301fc096afSMehdi Amini                                         });
6319b1d90e8SAlexander Belyaev }
6329b1d90e8SAlexander Belyaev 
6339b1d90e8SAlexander Belyaev //===----------------------------------------------------------------------===//
6349b1d90e8SAlexander Belyaev // MaxUIOp
6359b1d90e8SAlexander Belyaev //===----------------------------------------------------------------------===//
6369b1d90e8SAlexander Belyaev 
fold(ArrayRef<Attribute> operands)6379b1d90e8SAlexander Belyaev OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
6389b1d90e8SAlexander Belyaev   assert(operands.size() == 2 && "binary operation takes two operands");
6399b1d90e8SAlexander Belyaev 
6409b1d90e8SAlexander Belyaev   // maxui(x,x) -> x
6419b1d90e8SAlexander Belyaev   if (getLhs() == getRhs())
6429b1d90e8SAlexander Belyaev     return getRhs();
6439b1d90e8SAlexander Belyaev 
6449b1d90e8SAlexander Belyaev   APInt intValue;
6459b1d90e8SAlexander Belyaev   // maxui(x,MAX_INT) -> MAX_INT
6469b1d90e8SAlexander Belyaev   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
6479b1d90e8SAlexander Belyaev     return getRhs();
6489b1d90e8SAlexander Belyaev 
6499b1d90e8SAlexander Belyaev   // maxui(x, MIN_INT) -> x
6509b1d90e8SAlexander Belyaev   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
6519b1d90e8SAlexander Belyaev     return getLhs();
6529b1d90e8SAlexander Belyaev 
6531fc096afSMehdi Amini   return constFoldBinaryOp<IntegerAttr>(operands,
6541fc096afSMehdi Amini                                         [](const APInt &a, const APInt &b) {
6551fc096afSMehdi Amini                                           return llvm::APIntOps::umax(a, b);
6561fc096afSMehdi Amini                                         });
6579b1d90e8SAlexander Belyaev }
6589b1d90e8SAlexander Belyaev 
6599b1d90e8SAlexander Belyaev //===----------------------------------------------------------------------===//
660f278cf9cSChristian Sigg // MinFOp
661f278cf9cSChristian Sigg //===----------------------------------------------------------------------===//
662f278cf9cSChristian Sigg 
fold(ArrayRef<Attribute> operands)663f278cf9cSChristian Sigg OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
664f278cf9cSChristian Sigg   assert(operands.size() == 2 && "minf takes two operands");
665f278cf9cSChristian Sigg 
666f278cf9cSChristian Sigg   // minf(x,x) -> x
667f278cf9cSChristian Sigg   if (getLhs() == getRhs())
668f278cf9cSChristian Sigg     return getRhs();
669f278cf9cSChristian Sigg 
670f278cf9cSChristian Sigg   // minf(x, +inf) -> x
671f278cf9cSChristian Sigg   if (matchPattern(getRhs(), m_PosInfFloat()))
672f278cf9cSChristian Sigg     return getLhs();
673f278cf9cSChristian Sigg 
674f278cf9cSChristian Sigg   return constFoldBinaryOp<FloatAttr>(
675f278cf9cSChristian Sigg       operands,
676f278cf9cSChristian Sigg       [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
677f278cf9cSChristian Sigg }
678f278cf9cSChristian Sigg 
679f278cf9cSChristian Sigg //===----------------------------------------------------------------------===//
6809b1d90e8SAlexander Belyaev // MinSIOp
6819b1d90e8SAlexander Belyaev //===----------------------------------------------------------------------===//
6829b1d90e8SAlexander Belyaev 
fold(ArrayRef<Attribute> operands)6839b1d90e8SAlexander Belyaev OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
6849b1d90e8SAlexander Belyaev   assert(operands.size() == 2 && "binary operation takes two operands");
6859b1d90e8SAlexander Belyaev 
6869b1d90e8SAlexander Belyaev   // minsi(x,x) -> x
6879b1d90e8SAlexander Belyaev   if (getLhs() == getRhs())
6889b1d90e8SAlexander Belyaev     return getRhs();
6899b1d90e8SAlexander Belyaev 
6909b1d90e8SAlexander Belyaev   APInt intValue;
6919b1d90e8SAlexander Belyaev   // minsi(x,MIN_INT) -> MIN_INT
6929b1d90e8SAlexander Belyaev   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
6939b1d90e8SAlexander Belyaev       intValue.isMinSignedValue())
6949b1d90e8SAlexander Belyaev     return getRhs();
6959b1d90e8SAlexander Belyaev 
6969b1d90e8SAlexander Belyaev   // minsi(x, MAX_INT) -> x
6979b1d90e8SAlexander Belyaev   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
6989b1d90e8SAlexander Belyaev       intValue.isMaxSignedValue())
6999b1d90e8SAlexander Belyaev     return getLhs();
7009b1d90e8SAlexander Belyaev 
7011fc096afSMehdi Amini   return constFoldBinaryOp<IntegerAttr>(operands,
7021fc096afSMehdi Amini                                         [](const APInt &a, const APInt &b) {
7031fc096afSMehdi Amini                                           return llvm::APIntOps::smin(a, b);
7041fc096afSMehdi Amini                                         });
7059b1d90e8SAlexander Belyaev }
7069b1d90e8SAlexander Belyaev 
7079b1d90e8SAlexander Belyaev //===----------------------------------------------------------------------===//
7089b1d90e8SAlexander Belyaev // MinUIOp
7099b1d90e8SAlexander Belyaev //===----------------------------------------------------------------------===//
7109b1d90e8SAlexander Belyaev 
fold(ArrayRef<Attribute> operands)7119b1d90e8SAlexander Belyaev OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
7129b1d90e8SAlexander Belyaev   assert(operands.size() == 2 && "binary operation takes two operands");
7139b1d90e8SAlexander Belyaev 
7149b1d90e8SAlexander Belyaev   // minui(x,x) -> x
7159b1d90e8SAlexander Belyaev   if (getLhs() == getRhs())
7169b1d90e8SAlexander Belyaev     return getRhs();
7179b1d90e8SAlexander Belyaev 
7189b1d90e8SAlexander Belyaev   APInt intValue;
7199b1d90e8SAlexander Belyaev   // minui(x,MIN_INT) -> MIN_INT
7209b1d90e8SAlexander Belyaev   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
7219b1d90e8SAlexander Belyaev     return getRhs();
7229b1d90e8SAlexander Belyaev 
7239b1d90e8SAlexander Belyaev   // minui(x, MAX_INT) -> x
7249b1d90e8SAlexander Belyaev   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
7259b1d90e8SAlexander Belyaev     return getLhs();
7269b1d90e8SAlexander Belyaev 
7271fc096afSMehdi Amini   return constFoldBinaryOp<IntegerAttr>(operands,
7281fc096afSMehdi Amini                                         [](const APInt &a, const APInt &b) {
7291fc096afSMehdi Amini                                           return llvm::APIntOps::umin(a, b);
7301fc096afSMehdi Amini                                         });
7319b1d90e8SAlexander Belyaev }
7329b1d90e8SAlexander Belyaev 
7339b1d90e8SAlexander Belyaev //===----------------------------------------------------------------------===//
7348c08f21bSMogball // MulFOp
7358c08f21bSMogball //===----------------------------------------------------------------------===//
7368c08f21bSMogball 
fold(ArrayRef<Attribute> operands)7378c08f21bSMogball OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
738f278cf9cSChristian Sigg   // mulf(x, 1) -> x
739f278cf9cSChristian Sigg   if (matchPattern(getRhs(), m_OneFloat()))
740f278cf9cSChristian Sigg     return getLhs();
741f278cf9cSChristian Sigg 
7428c08f21bSMogball   return constFoldBinaryOp<FloatAttr>(
7431fc096afSMehdi Amini       operands, [](const APFloat &a, const APFloat &b) { return a * b; });
7448c08f21bSMogball }
7458c08f21bSMogball 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)7465179f885Sjacquesguan void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
7475179f885Sjacquesguan                                                 MLIRContext *context) {
7485179f885Sjacquesguan   patterns.add<MulFOfNegF>(context);
7495179f885Sjacquesguan }
7505179f885Sjacquesguan 
7518c08f21bSMogball //===----------------------------------------------------------------------===//
7528c08f21bSMogball // DivFOp
7538c08f21bSMogball //===----------------------------------------------------------------------===//
7548c08f21bSMogball 
fold(ArrayRef<Attribute> operands)7558c08f21bSMogball OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
756f278cf9cSChristian Sigg   // divf(x, 1) -> x
757f278cf9cSChristian Sigg   if (matchPattern(getRhs(), m_OneFloat()))
758f278cf9cSChristian Sigg     return getLhs();
759f278cf9cSChristian Sigg 
7608c08f21bSMogball   return constFoldBinaryOp<FloatAttr>(
7611fc096afSMehdi Amini       operands, [](const APFloat &a, const APFloat &b) { return a / b; });
7628c08f21bSMogball }
7638c08f21bSMogball 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)7645179f885Sjacquesguan void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
7655179f885Sjacquesguan                                                 MLIRContext *context) {
7665179f885Sjacquesguan   patterns.add<DivFOfNegF>(context);
7675179f885Sjacquesguan }
7685179f885Sjacquesguan 
7698c08f21bSMogball //===----------------------------------------------------------------------===//
77019e28547Sjacquesguan // RemFOp
77119e28547Sjacquesguan //===----------------------------------------------------------------------===//
77219e28547Sjacquesguan 
fold(ArrayRef<Attribute> operands)77319e28547Sjacquesguan OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) {
77419e28547Sjacquesguan   return constFoldBinaryOp<FloatAttr>(operands,
77519e28547Sjacquesguan                                       [](const APFloat &a, const APFloat &b) {
776c2828b63SMehdi Amini                                         APFloat result(a);
777c2828b63SMehdi Amini                                         (void)result.remainder(b);
778c2828b63SMehdi Amini                                         return result;
77919e28547Sjacquesguan                                       });
78019e28547Sjacquesguan }
78119e28547Sjacquesguan 
78219e28547Sjacquesguan //===----------------------------------------------------------------------===//
783a54f4eaeSMogball // Utility functions for verifying cast ops
784a54f4eaeSMogball //===----------------------------------------------------------------------===//
785a54f4eaeSMogball 
786a54f4eaeSMogball template <typename... Types>
787a54f4eaeSMogball using type_list = std::tuple<Types...> *;
788a54f4eaeSMogball 
789a54f4eaeSMogball /// Returns a non-null type only if the provided type is one of the allowed
790a54f4eaeSMogball /// types or one of the allowed shaped types of the allowed types. Returns the
791a54f4eaeSMogball /// element type if a valid shaped type is provided.
792a54f4eaeSMogball template <typename... ShapedTypes, typename... ElementTypes>
getUnderlyingType(Type type,type_list<ShapedTypes...>,type_list<ElementTypes...>)793a54f4eaeSMogball static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
794a54f4eaeSMogball                               type_list<ElementTypes...>) {
795a54f4eaeSMogball   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
796a54f4eaeSMogball     return {};
797a54f4eaeSMogball 
798a54f4eaeSMogball   auto underlyingType = getElementTypeOrSelf(type);
799a54f4eaeSMogball   if (!underlyingType.isa<ElementTypes...>())
800a54f4eaeSMogball     return {};
801a54f4eaeSMogball 
802a54f4eaeSMogball   return underlyingType;
803a54f4eaeSMogball }
804a54f4eaeSMogball 
805a54f4eaeSMogball /// Get allowed underlying types for vectors and tensors.
806a54f4eaeSMogball template <typename... ElementTypes>
getTypeIfLike(Type type)807a54f4eaeSMogball static Type getTypeIfLike(Type type) {
808a54f4eaeSMogball   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
809a54f4eaeSMogball                            type_list<ElementTypes...>());
810a54f4eaeSMogball }
811a54f4eaeSMogball 
812a54f4eaeSMogball /// Get allowed underlying types for vectors, tensors, and memrefs.
813a54f4eaeSMogball template <typename... ElementTypes>
getTypeIfLikeOrMemRef(Type type)814a54f4eaeSMogball static Type getTypeIfLikeOrMemRef(Type type) {
815a54f4eaeSMogball   return getUnderlyingType(type,
816a54f4eaeSMogball                            type_list<VectorType, TensorType, MemRefType>(),
817a54f4eaeSMogball                            type_list<ElementTypes...>());
818a54f4eaeSMogball }
819a54f4eaeSMogball 
areValidCastInputsAndOutputs(TypeRange inputs,TypeRange outputs)820a54f4eaeSMogball static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
821a54f4eaeSMogball   return inputs.size() == 1 && outputs.size() == 1 &&
822a54f4eaeSMogball          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
823a54f4eaeSMogball }
824a54f4eaeSMogball 
825a54f4eaeSMogball //===----------------------------------------------------------------------===//
8268c08f21bSMogball // Verifiers for integer and floating point extension/truncation ops
8278c08f21bSMogball //===----------------------------------------------------------------------===//
8288c08f21bSMogball 
8298c08f21bSMogball // Extend ops can only extend to a wider type.
8308c08f21bSMogball template <typename ValType, typename Op>
verifyExtOp(Op op)8318c08f21bSMogball static LogicalResult verifyExtOp(Op op) {
832cfb72fd3SJacques Pienaar   Type srcType = getElementTypeOrSelf(op.getIn().getType());
8338c08f21bSMogball   Type dstType = getElementTypeOrSelf(op.getType());
8348c08f21bSMogball 
8358c08f21bSMogball   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
8368c08f21bSMogball     return op.emitError("result type ")
8378c08f21bSMogball            << dstType << " must be wider than operand type " << srcType;
8388c08f21bSMogball 
8398c08f21bSMogball   return success();
8408c08f21bSMogball }
8418c08f21bSMogball 
8428c08f21bSMogball // Truncate ops can only truncate to a shorter type.
8438c08f21bSMogball template <typename ValType, typename Op>
verifyTruncateOp(Op op)8448c08f21bSMogball static LogicalResult verifyTruncateOp(Op op) {
845cfb72fd3SJacques Pienaar   Type srcType = getElementTypeOrSelf(op.getIn().getType());
8468c08f21bSMogball   Type dstType = getElementTypeOrSelf(op.getType());
8478c08f21bSMogball 
8488c08f21bSMogball   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
8498c08f21bSMogball     return op.emitError("result type ")
8508c08f21bSMogball            << dstType << " must be shorter than operand type " << srcType;
8518c08f21bSMogball 
8528c08f21bSMogball   return success();
8538c08f21bSMogball }
8548c08f21bSMogball 
855a54f4eaeSMogball /// Validate a cast that changes the width of a type.
856a54f4eaeSMogball template <template <typename> class WidthComparator, typename... ElementTypes>
checkWidthChangeCast(TypeRange inputs,TypeRange outputs)857a54f4eaeSMogball static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
858a54f4eaeSMogball   if (!areValidCastInputsAndOutputs(inputs, outputs))
859a54f4eaeSMogball     return false;
860a54f4eaeSMogball 
861a54f4eaeSMogball   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
862a54f4eaeSMogball   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
863a54f4eaeSMogball   if (!srcType || !dstType)
864a54f4eaeSMogball     return false;
865a54f4eaeSMogball 
866a54f4eaeSMogball   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
867a54f4eaeSMogball                                      srcType.getIntOrFloatBitWidth());
868a54f4eaeSMogball }
869a54f4eaeSMogball 
8708c08f21bSMogball //===----------------------------------------------------------------------===//
8718c08f21bSMogball // ExtUIOp
8728c08f21bSMogball //===----------------------------------------------------------------------===//
8738c08f21bSMogball 
fold(ArrayRef<Attribute> operands)8748c08f21bSMogball OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
8751bb9f4e4SWilliam S. Moses   if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
8761bb9f4e4SWilliam S. Moses     getInMutable().assign(lhs.getIn());
8771bb9f4e4SWilliam S. Moses     return getResult();
8781bb9f4e4SWilliam S. Moses   }
879605fc89aSjacquesguan   Type resType = getType();
880605fc89aSjacquesguan   unsigned bitWidth;
881605fc89aSjacquesguan   if (auto shapedType = resType.dyn_cast<ShapedType>())
882605fc89aSjacquesguan     bitWidth = shapedType.getElementTypeBitWidth();
883605fc89aSjacquesguan   else
884605fc89aSjacquesguan     bitWidth = resType.getIntOrFloatBitWidth();
885605fc89aSjacquesguan   return constFoldCastOp<IntegerAttr, IntegerAttr>(
886605fc89aSjacquesguan       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
887605fc89aSjacquesguan         return a.zext(bitWidth);
888605fc89aSjacquesguan       });
8898c08f21bSMogball }
8908c08f21bSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)891a54f4eaeSMogball bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
892a54f4eaeSMogball   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
893a54f4eaeSMogball }
894a54f4eaeSMogball 
verify()8951be88f5aSRiver Riddle LogicalResult arith::ExtUIOp::verify() {
8961be88f5aSRiver Riddle   return verifyExtOp<IntegerType>(*this);
8971be88f5aSRiver Riddle }
8981be88f5aSRiver Riddle 
8998c08f21bSMogball //===----------------------------------------------------------------------===//
9008c08f21bSMogball // ExtSIOp
9018c08f21bSMogball //===----------------------------------------------------------------------===//
9028c08f21bSMogball 
fold(ArrayRef<Attribute> operands)9038c08f21bSMogball OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
9041bb9f4e4SWilliam S. Moses   if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
9051bb9f4e4SWilliam S. Moses     getInMutable().assign(lhs.getIn());
9061bb9f4e4SWilliam S. Moses     return getResult();
9071bb9f4e4SWilliam S. Moses   }
908605fc89aSjacquesguan   Type resType = getType();
909605fc89aSjacquesguan   unsigned bitWidth;
910605fc89aSjacquesguan   if (auto shapedType = resType.dyn_cast<ShapedType>())
911605fc89aSjacquesguan     bitWidth = shapedType.getElementTypeBitWidth();
912605fc89aSjacquesguan   else
913605fc89aSjacquesguan     bitWidth = resType.getIntOrFloatBitWidth();
914605fc89aSjacquesguan   return constFoldCastOp<IntegerAttr, IntegerAttr>(
915605fc89aSjacquesguan       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
916605fc89aSjacquesguan         return a.sext(bitWidth);
917605fc89aSjacquesguan       });
9188c08f21bSMogball }
9198c08f21bSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)920a54f4eaeSMogball bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
921a54f4eaeSMogball   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
9228f0c673dSMogball }
9238f0c673dSMogball 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)924b7f93c28SJeff Niu void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
925b7f93c28SJeff Niu                                                  MLIRContext *context) {
926b4e0507cSTres Popp   patterns.add<ExtSIOfExtUI>(context);
9271bb9f4e4SWilliam S. Moses }
9281bb9f4e4SWilliam S. Moses 
verify()9291be88f5aSRiver Riddle LogicalResult arith::ExtSIOp::verify() {
9301be88f5aSRiver Riddle   return verifyExtOp<IntegerType>(*this);
9311be88f5aSRiver Riddle }
9321be88f5aSRiver Riddle 
933a54f4eaeSMogball //===----------------------------------------------------------------------===//
934a54f4eaeSMogball // ExtFOp
935a54f4eaeSMogball //===----------------------------------------------------------------------===//
936a54f4eaeSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)937a54f4eaeSMogball bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
938a54f4eaeSMogball   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
9398f0c673dSMogball }
9408f0c673dSMogball 
verify()9411be88f5aSRiver Riddle LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
9421be88f5aSRiver Riddle 
943a54f4eaeSMogball //===----------------------------------------------------------------------===//
944a54f4eaeSMogball // TruncIOp
945a54f4eaeSMogball //===----------------------------------------------------------------------===//
946a54f4eaeSMogball 
fold(ArrayRef<Attribute> operands)9478f0c673dSMogball OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
94834646a2fSWilliam S. Moses   assert(operands.size() == 1 && "unary operation takes one operand");
94934646a2fSWilliam S. Moses 
950a54f4eaeSMogball   // trunci(zexti(a)) -> a
951a54f4eaeSMogball   // trunci(sexti(a)) -> a
952a54f4eaeSMogball   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
953a54f4eaeSMogball       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
954a54f4eaeSMogball     return getOperand().getDefiningOp()->getOperand(0);
955a54f4eaeSMogball 
95634646a2fSWilliam S. Moses   // trunci(trunci(a)) -> trunci(a))
95734646a2fSWilliam S. Moses   if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
95834646a2fSWilliam S. Moses     setOperand(getOperand().getDefiningOp()->getOperand(0));
95934646a2fSWilliam S. Moses     return getResult();
96034646a2fSWilliam S. Moses   }
961a54f4eaeSMogball 
962605fc89aSjacquesguan   Type resType = getType();
963605fc89aSjacquesguan   unsigned bitWidth;
964605fc89aSjacquesguan   if (auto shapedType = resType.dyn_cast<ShapedType>())
965605fc89aSjacquesguan     bitWidth = shapedType.getElementTypeBitWidth();
966605fc89aSjacquesguan   else
967605fc89aSjacquesguan     bitWidth = resType.getIntOrFloatBitWidth();
968a54f4eaeSMogball 
969605fc89aSjacquesguan   return constFoldCastOp<IntegerAttr, IntegerAttr>(
970605fc89aSjacquesguan       operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
971605fc89aSjacquesguan         return a.trunc(bitWidth);
972605fc89aSjacquesguan       });
9738f0c673dSMogball }
9748f0c673dSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)9758f0c673dSMogball bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
976a54f4eaeSMogball   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
9778f0c673dSMogball }
9788f0c673dSMogball 
verify()9791be88f5aSRiver Riddle LogicalResult arith::TruncIOp::verify() {
9801be88f5aSRiver Riddle   return verifyTruncateOp<IntegerType>(*this);
9811be88f5aSRiver Riddle }
9821be88f5aSRiver Riddle 
983a54f4eaeSMogball //===----------------------------------------------------------------------===//
984a54f4eaeSMogball // TruncFOp
985a54f4eaeSMogball //===----------------------------------------------------------------------===//
9868f0c673dSMogball 
987a54f4eaeSMogball /// Perform safe const propagation for truncf, i.e. only propagate if FP value
988a54f4eaeSMogball /// can be represented without precision loss or rounding.
fold(ArrayRef<Attribute> operands)989a54f4eaeSMogball OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
990a54f4eaeSMogball   assert(operands.size() == 1 && "unary operation takes one operand");
9918f0c673dSMogball 
992a54f4eaeSMogball   auto constOperand = operands.front();
993a54f4eaeSMogball   if (!constOperand || !constOperand.isa<FloatAttr>())
994a54f4eaeSMogball     return {};
9958f0c673dSMogball 
996a54f4eaeSMogball   // Convert to target type via 'double'.
997a54f4eaeSMogball   double sourceValue =
998a54f4eaeSMogball       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
999a54f4eaeSMogball   auto targetAttr = FloatAttr::get(getType(), sourceValue);
1000a54f4eaeSMogball 
1001a54f4eaeSMogball   // Propagate if constant's value does not change after truncation.
1002a54f4eaeSMogball   if (sourceValue == targetAttr.getValue().convertToDouble())
1003a54f4eaeSMogball     return targetAttr;
1004a54f4eaeSMogball 
10058f0c673dSMogball   return {};
10068f0c673dSMogball }
10078f0c673dSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)1008a54f4eaeSMogball bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1009a54f4eaeSMogball   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
10108f0c673dSMogball }
10118f0c673dSMogball 
verify()10121be88f5aSRiver Riddle LogicalResult arith::TruncFOp::verify() {
10131be88f5aSRiver Riddle   return verifyTruncateOp<FloatType>(*this);
10141be88f5aSRiver Riddle }
10151be88f5aSRiver Riddle 
1016a54f4eaeSMogball //===----------------------------------------------------------------------===//
1017834cf3beSWilliam S. Moses // AndIOp
1018834cf3beSWilliam S. Moses //===----------------------------------------------------------------------===//
1019834cf3beSWilliam S. Moses 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1020b7f93c28SJeff Niu void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1021b7f93c28SJeff Niu                                                 MLIRContext *context) {
1022b4e0507cSTres Popp   patterns.add<AndOfExtUI, AndOfExtSI>(context);
1023834cf3beSWilliam S. Moses }
1024834cf3beSWilliam S. Moses 
1025834cf3beSWilliam S. Moses //===----------------------------------------------------------------------===//
1026834cf3beSWilliam S. Moses // OrIOp
1027834cf3beSWilliam S. Moses //===----------------------------------------------------------------------===//
1028834cf3beSWilliam S. Moses 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1029b7f93c28SJeff Niu void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1030b7f93c28SJeff Niu                                                MLIRContext *context) {
1031b4e0507cSTres Popp   patterns.add<OrOfExtUI, OrOfExtSI>(context);
1032834cf3beSWilliam S. Moses }
1033834cf3beSWilliam S. Moses 
1034834cf3beSWilliam S. Moses //===----------------------------------------------------------------------===//
1035a54f4eaeSMogball // Verifiers for casts between integers and floats.
1036a54f4eaeSMogball //===----------------------------------------------------------------------===//
1037a54f4eaeSMogball 
1038a54f4eaeSMogball template <typename From, typename To>
checkIntFloatCast(TypeRange inputs,TypeRange outputs)1039a54f4eaeSMogball static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1040a54f4eaeSMogball   if (!areValidCastInputsAndOutputs(inputs, outputs))
1041a54f4eaeSMogball     return false;
1042a54f4eaeSMogball 
1043a54f4eaeSMogball   auto srcType = getTypeIfLike<From>(inputs.front());
1044a54f4eaeSMogball   auto dstType = getTypeIfLike<To>(outputs.back());
1045a54f4eaeSMogball 
1046a54f4eaeSMogball   return srcType && dstType;
1047a54f4eaeSMogball }
1048a54f4eaeSMogball 
1049a54f4eaeSMogball //===----------------------------------------------------------------------===//
1050a54f4eaeSMogball // UIToFPOp
1051a54f4eaeSMogball //===----------------------------------------------------------------------===//
1052a54f4eaeSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)10538f0c673dSMogball bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1054a54f4eaeSMogball   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
10558f0c673dSMogball }
10568f0c673dSMogball 
fold(ArrayRef<Attribute> operands)1057ca8997ebSWilliam S. Moses OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1058605fc89aSjacquesguan   Type resType = getType();
1059605fc89aSjacquesguan   Type resEleType;
1060605fc89aSjacquesguan   if (auto shapedType = resType.dyn_cast<ShapedType>())
1061605fc89aSjacquesguan     resEleType = shapedType.getElementType();
1062605fc89aSjacquesguan   else
1063605fc89aSjacquesguan     resEleType = resType;
1064605fc89aSjacquesguan   return constFoldCastOp<IntegerAttr, FloatAttr>(
1065605fc89aSjacquesguan       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1066605fc89aSjacquesguan         FloatType floatTy = resEleType.cast<FloatType>();
1067ca8997ebSWilliam S. Moses         APFloat apf(floatTy.getFloatSemantics(),
1068ca8997ebSWilliam S. Moses                     APInt::getZero(floatTy.getWidth()));
1069605fc89aSjacquesguan         apf.convertFromAPInt(a, /*IsSigned=*/false,
1070605fc89aSjacquesguan                              APFloat::rmNearestTiesToEven);
1071605fc89aSjacquesguan         return apf;
1072605fc89aSjacquesguan       });
1073ca8997ebSWilliam S. Moses }
1074ca8997ebSWilliam S. Moses 
1075a54f4eaeSMogball //===----------------------------------------------------------------------===//
1076a54f4eaeSMogball // SIToFPOp
1077a54f4eaeSMogball //===----------------------------------------------------------------------===//
1078a54f4eaeSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)1079a54f4eaeSMogball bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1080a54f4eaeSMogball   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
10818f0c673dSMogball }
10828f0c673dSMogball 
fold(ArrayRef<Attribute> operands)1083ca8997ebSWilliam S. Moses OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1084605fc89aSjacquesguan   Type resType = getType();
1085605fc89aSjacquesguan   Type resEleType;
1086605fc89aSjacquesguan   if (auto shapedType = resType.dyn_cast<ShapedType>())
1087605fc89aSjacquesguan     resEleType = shapedType.getElementType();
1088605fc89aSjacquesguan   else
1089605fc89aSjacquesguan     resEleType = resType;
1090605fc89aSjacquesguan   return constFoldCastOp<IntegerAttr, FloatAttr>(
1091605fc89aSjacquesguan       operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1092605fc89aSjacquesguan         FloatType floatTy = resEleType.cast<FloatType>();
1093ca8997ebSWilliam S. Moses         APFloat apf(floatTy.getFloatSemantics(),
1094ca8997ebSWilliam S. Moses                     APInt::getZero(floatTy.getWidth()));
1095605fc89aSjacquesguan         apf.convertFromAPInt(a, /*IsSigned=*/true,
1096605fc89aSjacquesguan                              APFloat::rmNearestTiesToEven);
1097605fc89aSjacquesguan         return apf;
1098605fc89aSjacquesguan       });
1099ca8997ebSWilliam S. Moses }
1100a54f4eaeSMogball //===----------------------------------------------------------------------===//
1101a54f4eaeSMogball // FPToUIOp
1102a54f4eaeSMogball //===----------------------------------------------------------------------===//
1103a54f4eaeSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)11048f0c673dSMogball bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1105a54f4eaeSMogball   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1106a54f4eaeSMogball }
1107a54f4eaeSMogball 
fold(ArrayRef<Attribute> operands)1108ca8997ebSWilliam S. Moses OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1109605fc89aSjacquesguan   Type resType = getType();
1110605fc89aSjacquesguan   Type resEleType;
1111605fc89aSjacquesguan   if (auto shapedType = resType.dyn_cast<ShapedType>())
1112605fc89aSjacquesguan     resEleType = shapedType.getElementType();
1113605fc89aSjacquesguan   else
1114605fc89aSjacquesguan     resEleType = resType;
1115605fc89aSjacquesguan   return constFoldCastOp<FloatAttr, IntegerAttr>(
1116605fc89aSjacquesguan       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1117605fc89aSjacquesguan         IntegerType intTy = resEleType.cast<IntegerType>();
1118ca8997ebSWilliam S. Moses         bool ignored;
11195caee217SMehdi Amini         APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1120605fc89aSjacquesguan         castStatus = APFloat::opInvalidOp !=
1121605fc89aSjacquesguan                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1122605fc89aSjacquesguan         return api;
1123605fc89aSjacquesguan       });
1124ca8997ebSWilliam S. Moses }
1125ca8997ebSWilliam S. Moses 
1126a54f4eaeSMogball //===----------------------------------------------------------------------===//
1127a54f4eaeSMogball // FPToSIOp
1128a54f4eaeSMogball //===----------------------------------------------------------------------===//
1129a54f4eaeSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)1130a54f4eaeSMogball bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1131a54f4eaeSMogball   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
11328f0c673dSMogball }
11338f0c673dSMogball 
fold(ArrayRef<Attribute> operands)1134ca8997ebSWilliam S. Moses OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1135605fc89aSjacquesguan   Type resType = getType();
1136605fc89aSjacquesguan   Type resEleType;
1137605fc89aSjacquesguan   if (auto shapedType = resType.dyn_cast<ShapedType>())
1138605fc89aSjacquesguan     resEleType = shapedType.getElementType();
1139605fc89aSjacquesguan   else
1140605fc89aSjacquesguan     resEleType = resType;
1141605fc89aSjacquesguan   return constFoldCastOp<FloatAttr, IntegerAttr>(
1142605fc89aSjacquesguan       operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1143605fc89aSjacquesguan         IntegerType intTy = resEleType.cast<IntegerType>();
1144ca8997ebSWilliam S. Moses         bool ignored;
11455caee217SMehdi Amini         APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1146605fc89aSjacquesguan         castStatus = APFloat::opInvalidOp !=
1147605fc89aSjacquesguan                      a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1148605fc89aSjacquesguan         return api;
1149605fc89aSjacquesguan       });
1150ca8997ebSWilliam S. Moses }
1151ca8997ebSWilliam S. Moses 
11528c08f21bSMogball //===----------------------------------------------------------------------===//
11538c08f21bSMogball // IndexCastOp
11548c08f21bSMogball //===----------------------------------------------------------------------===//
11558c08f21bSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)11568c08f21bSMogball bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
11578c08f21bSMogball                                            TypeRange outputs) {
1158a54f4eaeSMogball   if (!areValidCastInputsAndOutputs(inputs, outputs))
1159a54f4eaeSMogball     return false;
11608c08f21bSMogball 
1161a54f4eaeSMogball   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1162a54f4eaeSMogball   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1163a54f4eaeSMogball   if (!srcType || !dstType)
1164a54f4eaeSMogball     return false;
11658c08f21bSMogball 
11668c08f21bSMogball   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
11678c08f21bSMogball          (srcType.isSignlessInteger() && dstType.isIndex());
11688c08f21bSMogball }
11698c08f21bSMogball 
fold(ArrayRef<Attribute> operands)11708c08f21bSMogball OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
11718c08f21bSMogball   // index_cast(constant) -> constant
11728c08f21bSMogball   // A little hack because we go through int. Otherwise, the size of the
11738c08f21bSMogball   // constant might need to change.
11748c08f21bSMogball   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
11758c08f21bSMogball     return IntegerAttr::get(getType(), value.getInt());
11768c08f21bSMogball 
11778c08f21bSMogball   return {};
11788c08f21bSMogball }
11798c08f21bSMogball 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)11808c08f21bSMogball void arith::IndexCastOp::getCanonicalizationPatterns(
11819f85c198SRiver Riddle     RewritePatternSet &patterns, MLIRContext *context) {
1182b4e0507cSTres Popp   patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
11838c08f21bSMogball }
11848c08f21bSMogball 
11858c08f21bSMogball //===----------------------------------------------------------------------===//
11868c08f21bSMogball // BitcastOp
11878c08f21bSMogball //===----------------------------------------------------------------------===//
11888c08f21bSMogball 
areCastCompatible(TypeRange inputs,TypeRange outputs)11898c08f21bSMogball bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1190a54f4eaeSMogball   if (!areValidCastInputsAndOutputs(inputs, outputs))
1191a54f4eaeSMogball     return false;
11928c08f21bSMogball 
1193a54f4eaeSMogball   auto srcType =
1194a54f4eaeSMogball       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1195a54f4eaeSMogball   auto dstType =
1196a54f4eaeSMogball       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1197a54f4eaeSMogball   if (!srcType || !dstType)
1198a54f4eaeSMogball     return false;
11998c08f21bSMogball 
12008c08f21bSMogball   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
12018c08f21bSMogball }
12028c08f21bSMogball 
fold(ArrayRef<Attribute> operands)12038c08f21bSMogball OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
12048c08f21bSMogball   assert(operands.size() == 1 && "bitcast op expects 1 operand");
12058c08f21bSMogball 
12068c08f21bSMogball   auto resType = getType();
12078c08f21bSMogball   auto operand = operands[0];
12088c08f21bSMogball   if (!operand)
12098c08f21bSMogball     return {};
12108c08f21bSMogball 
12118c08f21bSMogball   /// Bitcast dense elements.
12128c08f21bSMogball   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
12138c08f21bSMogball     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
12148c08f21bSMogball   /// Other shaped types unhandled.
12158c08f21bSMogball   if (resType.isa<ShapedType>())
12168c08f21bSMogball     return {};
12178c08f21bSMogball 
12188c08f21bSMogball   /// Bitcast integer or float to integer or float.
12198c08f21bSMogball   APInt bits = operand.isa<FloatAttr>()
12208c08f21bSMogball                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
12218c08f21bSMogball                    : operand.cast<IntegerAttr>().getValue();
12228c08f21bSMogball 
12238c08f21bSMogball   if (auto resFloatType = resType.dyn_cast<FloatType>())
12248c08f21bSMogball     return FloatAttr::get(resType,
12258c08f21bSMogball                           APFloat(resFloatType.getFloatSemantics(), bits));
12268c08f21bSMogball   return IntegerAttr::get(resType, bits);
12278c08f21bSMogball }
12288c08f21bSMogball 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1229b7f93c28SJeff Niu void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1230b7f93c28SJeff Niu                                                    MLIRContext *context) {
1231b4e0507cSTres Popp   patterns.add<BitcastOfBitcast>(context);
12328c08f21bSMogball }
12338c08f21bSMogball 
12348c08f21bSMogball //===----------------------------------------------------------------------===//
12358c08f21bSMogball // Helpers for compare ops
12368c08f21bSMogball //===----------------------------------------------------------------------===//
12378c08f21bSMogball 
12388c08f21bSMogball /// Return the type of the same shape (scalar, vector or tensor) containing i1.
getI1SameShape(Type type)12398c08f21bSMogball static Type getI1SameShape(Type type) {
12408c08f21bSMogball   auto i1Type = IntegerType::get(type.getContext(), 1);
12418c08f21bSMogball   if (auto tensorType = type.dyn_cast<RankedTensorType>())
12428c08f21bSMogball     return RankedTensorType::get(tensorType.getShape(), i1Type);
12438c08f21bSMogball   if (type.isa<UnrankedTensorType>())
12448c08f21bSMogball     return UnrankedTensorType::get(i1Type);
12458c08f21bSMogball   if (auto vectorType = type.dyn_cast<VectorType>())
1246a4830d14SJavier Setoain     return VectorType::get(vectorType.getShape(), i1Type,
1247a4830d14SJavier Setoain                            vectorType.getNumScalableDims());
12488c08f21bSMogball   return i1Type;
12498c08f21bSMogball }
12508c08f21bSMogball 
12518c08f21bSMogball //===----------------------------------------------------------------------===//
12528c08f21bSMogball // CmpIOp
12538c08f21bSMogball //===----------------------------------------------------------------------===//
12548c08f21bSMogball 
12558c08f21bSMogball /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
12568c08f21bSMogball /// comparison predicates.
applyCmpPredicate(arith::CmpIPredicate predicate,const APInt & lhs,const APInt & rhs)12578c08f21bSMogball bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
12588c08f21bSMogball                                     const APInt &lhs, const APInt &rhs) {
12598c08f21bSMogball   switch (predicate) {
12608c08f21bSMogball   case arith::CmpIPredicate::eq:
12618c08f21bSMogball     return lhs.eq(rhs);
12628c08f21bSMogball   case arith::CmpIPredicate::ne:
12638c08f21bSMogball     return lhs.ne(rhs);
12648c08f21bSMogball   case arith::CmpIPredicate::slt:
12658c08f21bSMogball     return lhs.slt(rhs);
12668c08f21bSMogball   case arith::CmpIPredicate::sle:
12678c08f21bSMogball     return lhs.sle(rhs);
12688c08f21bSMogball   case arith::CmpIPredicate::sgt:
12698c08f21bSMogball     return lhs.sgt(rhs);
12708c08f21bSMogball   case arith::CmpIPredicate::sge:
12718c08f21bSMogball     return lhs.sge(rhs);
12728c08f21bSMogball   case arith::CmpIPredicate::ult:
12738c08f21bSMogball     return lhs.ult(rhs);
12748c08f21bSMogball   case arith::CmpIPredicate::ule:
12758c08f21bSMogball     return lhs.ule(rhs);
12768c08f21bSMogball   case arith::CmpIPredicate::ugt:
12778c08f21bSMogball     return lhs.ugt(rhs);
12788c08f21bSMogball   case arith::CmpIPredicate::uge:
12798c08f21bSMogball     return lhs.uge(rhs);
12808c08f21bSMogball   }
12818c08f21bSMogball   llvm_unreachable("unknown cmpi predicate kind");
12828c08f21bSMogball }
12838c08f21bSMogball 
12848c08f21bSMogball /// Returns true if the predicate is true for two equal operands.
applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)12858c08f21bSMogball static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
12868c08f21bSMogball   switch (predicate) {
12878c08f21bSMogball   case arith::CmpIPredicate::eq:
12888c08f21bSMogball   case arith::CmpIPredicate::sle:
12898c08f21bSMogball   case arith::CmpIPredicate::sge:
12908c08f21bSMogball   case arith::CmpIPredicate::ule:
12918c08f21bSMogball   case arith::CmpIPredicate::uge:
12928c08f21bSMogball     return true;
12938c08f21bSMogball   case arith::CmpIPredicate::ne:
12948c08f21bSMogball   case arith::CmpIPredicate::slt:
12958c08f21bSMogball   case arith::CmpIPredicate::sgt:
12968c08f21bSMogball   case arith::CmpIPredicate::ult:
12978c08f21bSMogball   case arith::CmpIPredicate::ugt:
12988c08f21bSMogball     return false;
12998c08f21bSMogball   }
13008c08f21bSMogball   llvm_unreachable("unknown cmpi predicate kind");
13018c08f21bSMogball }
13028c08f21bSMogball 
getBoolAttribute(Type type,MLIRContext * ctx,bool value)13034a10457dSAdrian Kuegel static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
13044a10457dSAdrian Kuegel   auto boolAttr = BoolAttr::get(ctx, value);
13054a10457dSAdrian Kuegel   ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
13064a10457dSAdrian Kuegel   if (!shapedType)
13074a10457dSAdrian Kuegel     return boolAttr;
13084a10457dSAdrian Kuegel   return DenseElementsAttr::get(shapedType, boolAttr);
13094a10457dSAdrian Kuegel }
13104a10457dSAdrian Kuegel 
fold(ArrayRef<Attribute> operands)13118c08f21bSMogball OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
13128c08f21bSMogball   assert(operands.size() == 2 && "cmpi takes two operands");
13138c08f21bSMogball 
13148c08f21bSMogball   // cmpi(pred, x, x)
1315cfb72fd3SJacques Pienaar   if (getLhs() == getRhs()) {
13168c08f21bSMogball     auto val = applyCmpPredicateToEqualOperands(getPredicate());
13174a10457dSAdrian Kuegel     return getBoolAttribute(getType(), getContext(), val);
13188c08f21bSMogball   }
13198c08f21bSMogball 
13201a0a1779SWilliam S. Moses   if (matchPattern(getRhs(), m_Zero())) {
13211a0a1779SWilliam S. Moses     if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
13221a0a1779SWilliam S. Moses       // extsi(%x : i1 -> iN) != 0  ->  %x
132310c9ecceSjacquesguan       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
132410c9ecceSjacquesguan           getPredicate() == arith::CmpIPredicate::ne)
13251a0a1779SWilliam S. Moses         return extOp.getOperand();
13261a0a1779SWilliam S. Moses     }
13271a0a1779SWilliam S. Moses     if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
13281a0a1779SWilliam S. Moses       // extui(%x : i1 -> iN) != 0  ->  %x
132910c9ecceSjacquesguan       if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
133010c9ecceSjacquesguan           getPredicate() == arith::CmpIPredicate::ne)
13311a0a1779SWilliam S. Moses         return extOp.getOperand();
13321a0a1779SWilliam S. Moses     }
13331a0a1779SWilliam S. Moses   }
13341a0a1779SWilliam S. Moses 
1335917e4519SIvan Butygin   // Move constant to the right side.
1336917e4519SIvan Butygin   if (operands[0] && !operands[1]) {
1337917e4519SIvan Butygin     // Do not use invertPredicate, as it will change eq to ne and vice versa.
1338917e4519SIvan Butygin     using Pred = CmpIPredicate;
1339917e4519SIvan Butygin     const std::pair<Pred, Pred> invPreds[] = {
1340917e4519SIvan Butygin         {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1341917e4519SIvan Butygin         {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1342917e4519SIvan Butygin         {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1343917e4519SIvan Butygin         {Pred::ne, Pred::ne},
1344917e4519SIvan Butygin     };
1345917e4519SIvan Butygin     Pred origPred = getPredicate();
1346917e4519SIvan Butygin     for (auto pred : invPreds) {
1347917e4519SIvan Butygin       if (origPred == pred.first) {
1348917e4519SIvan Butygin         setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
1349917e4519SIvan Butygin         Value lhs = getLhs();
1350917e4519SIvan Butygin         Value rhs = getRhs();
1351917e4519SIvan Butygin         getLhsMutable().assign(rhs);
1352917e4519SIvan Butygin         getRhsMutable().assign(lhs);
1353917e4519SIvan Butygin         return getResult();
1354917e4519SIvan Butygin       }
1355917e4519SIvan Butygin     }
1356917e4519SIvan Butygin     llvm_unreachable("unknown cmpi predicate kind");
1357917e4519SIvan Butygin   }
1358917e4519SIvan Butygin 
13598c08f21bSMogball   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1360917e4519SIvan Butygin   if (!lhs)
13618c08f21bSMogball     return {};
13628c08f21bSMogball 
1363917e4519SIvan Butygin   // We are moving constants to the right side; So if lhs is constant rhs is
1364917e4519SIvan Butygin   // guaranteed to be a constant.
1365917e4519SIvan Butygin   auto rhs = operands.back().cast<IntegerAttr>();
1366917e4519SIvan Butygin 
13678c08f21bSMogball   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
13688c08f21bSMogball   return BoolAttr::get(getContext(), val);
13698c08f21bSMogball }
13708c08f21bSMogball 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)13712af81c69SWilliam S. Moses void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
13722af81c69SWilliam S. Moses                                                 MLIRContext *context) {
13732af81c69SWilliam S. Moses   patterns.insert<CmpIExtSI, CmpIExtUI>(context);
13742af81c69SWilliam S. Moses }
13752af81c69SWilliam S. Moses 
13768c08f21bSMogball //===----------------------------------------------------------------------===//
13778c08f21bSMogball // CmpFOp
13788c08f21bSMogball //===----------------------------------------------------------------------===//
13798c08f21bSMogball 
13808c08f21bSMogball /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
13818c08f21bSMogball /// comparison predicates.
applyCmpPredicate(arith::CmpFPredicate predicate,const APFloat & lhs,const APFloat & rhs)13828c08f21bSMogball bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
13838c08f21bSMogball                                     const APFloat &lhs, const APFloat &rhs) {
13848c08f21bSMogball   auto cmpResult = lhs.compare(rhs);
13858c08f21bSMogball   switch (predicate) {
13868c08f21bSMogball   case arith::CmpFPredicate::AlwaysFalse:
13878c08f21bSMogball     return false;
13888c08f21bSMogball   case arith::CmpFPredicate::OEQ:
13898c08f21bSMogball     return cmpResult == APFloat::cmpEqual;
13908c08f21bSMogball   case arith::CmpFPredicate::OGT:
13918c08f21bSMogball     return cmpResult == APFloat::cmpGreaterThan;
13928c08f21bSMogball   case arith::CmpFPredicate::OGE:
13938c08f21bSMogball     return cmpResult == APFloat::cmpGreaterThan ||
13948c08f21bSMogball            cmpResult == APFloat::cmpEqual;
13958c08f21bSMogball   case arith::CmpFPredicate::OLT:
13968c08f21bSMogball     return cmpResult == APFloat::cmpLessThan;
13978c08f21bSMogball   case arith::CmpFPredicate::OLE:
13988c08f21bSMogball     return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
13998c08f21bSMogball   case arith::CmpFPredicate::ONE:
14008c08f21bSMogball     return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
14018c08f21bSMogball   case arith::CmpFPredicate::ORD:
14028c08f21bSMogball     return cmpResult != APFloat::cmpUnordered;
14038c08f21bSMogball   case arith::CmpFPredicate::UEQ:
14048c08f21bSMogball     return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
14058c08f21bSMogball   case arith::CmpFPredicate::UGT:
14068c08f21bSMogball     return cmpResult == APFloat::cmpUnordered ||
14078c08f21bSMogball            cmpResult == APFloat::cmpGreaterThan;
14088c08f21bSMogball   case arith::CmpFPredicate::UGE:
14098c08f21bSMogball     return cmpResult == APFloat::cmpUnordered ||
14108c08f21bSMogball            cmpResult == APFloat::cmpGreaterThan ||
14118c08f21bSMogball            cmpResult == APFloat::cmpEqual;
14128c08f21bSMogball   case arith::CmpFPredicate::ULT:
14138c08f21bSMogball     return cmpResult == APFloat::cmpUnordered ||
14148c08f21bSMogball            cmpResult == APFloat::cmpLessThan;
14158c08f21bSMogball   case arith::CmpFPredicate::ULE:
14168c08f21bSMogball     return cmpResult == APFloat::cmpUnordered ||
14178c08f21bSMogball            cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
14188c08f21bSMogball   case arith::CmpFPredicate::UNE:
14198c08f21bSMogball     return cmpResult != APFloat::cmpEqual;
14208c08f21bSMogball   case arith::CmpFPredicate::UNO:
14218c08f21bSMogball     return cmpResult == APFloat::cmpUnordered;
14228c08f21bSMogball   case arith::CmpFPredicate::AlwaysTrue:
14238c08f21bSMogball     return true;
14248c08f21bSMogball   }
14258c08f21bSMogball   llvm_unreachable("unknown cmpf predicate kind");
14268c08f21bSMogball }
14278c08f21bSMogball 
fold(ArrayRef<Attribute> operands)14288c08f21bSMogball OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
14298c08f21bSMogball   assert(operands.size() == 2 && "cmpf takes two operands");
14308c08f21bSMogball 
14318c08f21bSMogball   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
14328c08f21bSMogball   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
14338c08f21bSMogball 
1434f6fab68cSChristian Sigg   // If one operand is NaN, making them both NaN does not change the result.
1435f6fab68cSChristian Sigg   if (lhs && lhs.getValue().isNaN())
1436f6fab68cSChristian Sigg     rhs = lhs;
1437f6fab68cSChristian Sigg   if (rhs && rhs.getValue().isNaN())
1438f6fab68cSChristian Sigg     lhs = rhs;
1439f6fab68cSChristian Sigg 
14408c08f21bSMogball   if (!lhs || !rhs)
14418c08f21bSMogball     return {};
14428c08f21bSMogball 
14438c08f21bSMogball   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
14448c08f21bSMogball   return BoolAttr::get(getContext(), val);
14458c08f21bSMogball }
14468c08f21bSMogball 
14471b2a1f84SWilliam S. Moses class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
14481b2a1f84SWilliam S. Moses public:
14491b2a1f84SWilliam S. Moses   using OpRewritePattern<CmpFOp>::OpRewritePattern;
14501b2a1f84SWilliam S. Moses 
convertToIntegerPredicate(CmpFPredicate pred,bool isUnsigned)14511b2a1f84SWilliam S. Moses   static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
14521b2a1f84SWilliam S. Moses                                                  bool isUnsigned) {
14531b2a1f84SWilliam S. Moses     using namespace arith;
14541b2a1f84SWilliam S. Moses     switch (pred) {
14551b2a1f84SWilliam S. Moses     case CmpFPredicate::UEQ:
14561b2a1f84SWilliam S. Moses     case CmpFPredicate::OEQ:
14571b2a1f84SWilliam S. Moses       return CmpIPredicate::eq;
14581b2a1f84SWilliam S. Moses     case CmpFPredicate::UGT:
14591b2a1f84SWilliam S. Moses     case CmpFPredicate::OGT:
14601b2a1f84SWilliam S. Moses       return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
14611b2a1f84SWilliam S. Moses     case CmpFPredicate::UGE:
14621b2a1f84SWilliam S. Moses     case CmpFPredicate::OGE:
14631b2a1f84SWilliam S. Moses       return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
14641b2a1f84SWilliam S. Moses     case CmpFPredicate::ULT:
14651b2a1f84SWilliam S. Moses     case CmpFPredicate::OLT:
14661b2a1f84SWilliam S. Moses       return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
14671b2a1f84SWilliam S. Moses     case CmpFPredicate::ULE:
14681b2a1f84SWilliam S. Moses     case CmpFPredicate::OLE:
14691b2a1f84SWilliam S. Moses       return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
14701b2a1f84SWilliam S. Moses     case CmpFPredicate::UNE:
14711b2a1f84SWilliam S. Moses     case CmpFPredicate::ONE:
14721b2a1f84SWilliam S. Moses       return CmpIPredicate::ne;
14731b2a1f84SWilliam S. Moses     default:
14741b2a1f84SWilliam S. Moses       llvm_unreachable("Unexpected predicate!");
14751b2a1f84SWilliam S. Moses     }
14761b2a1f84SWilliam S. Moses   }
14771b2a1f84SWilliam S. Moses 
matchAndRewrite(CmpFOp op,PatternRewriter & rewriter) const14781b2a1f84SWilliam S. Moses   LogicalResult matchAndRewrite(CmpFOp op,
14791b2a1f84SWilliam S. Moses                                 PatternRewriter &rewriter) const override {
14801b2a1f84SWilliam S. Moses     FloatAttr flt;
14811b2a1f84SWilliam S. Moses     if (!matchPattern(op.getRhs(), m_Constant(&flt)))
14821b2a1f84SWilliam S. Moses       return failure();
14831b2a1f84SWilliam S. Moses 
14841b2a1f84SWilliam S. Moses     const APFloat &rhs = flt.getValue();
14851b2a1f84SWilliam S. Moses 
14861b2a1f84SWilliam S. Moses     // Don't attempt to fold a nan.
14871b2a1f84SWilliam S. Moses     if (rhs.isNaN())
14881b2a1f84SWilliam S. Moses       return failure();
14891b2a1f84SWilliam S. Moses 
14901b2a1f84SWilliam S. Moses     // Get the width of the mantissa.  We don't want to hack on conversions that
14911b2a1f84SWilliam S. Moses     // might lose information from the integer, e.g. "i64 -> float"
14921b2a1f84SWilliam S. Moses     FloatType floatTy = op.getRhs().getType().cast<FloatType>();
14931b2a1f84SWilliam S. Moses     int mantissaWidth = floatTy.getFPMantissaWidth();
14941b2a1f84SWilliam S. Moses     if (mantissaWidth <= 0)
14951b2a1f84SWilliam S. Moses       return failure();
14961b2a1f84SWilliam S. Moses 
14971b2a1f84SWilliam S. Moses     bool isUnsigned;
14981b2a1f84SWilliam S. Moses     Value intVal;
14991b2a1f84SWilliam S. Moses 
15001b2a1f84SWilliam S. Moses     if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
15011b2a1f84SWilliam S. Moses       isUnsigned = false;
15021b2a1f84SWilliam S. Moses       intVal = si.getIn();
15031b2a1f84SWilliam S. Moses     } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
15041b2a1f84SWilliam S. Moses       isUnsigned = true;
15051b2a1f84SWilliam S. Moses       intVal = ui.getIn();
15061b2a1f84SWilliam S. Moses     } else {
15071b2a1f84SWilliam S. Moses       return failure();
15081b2a1f84SWilliam S. Moses     }
15091b2a1f84SWilliam S. Moses 
15101b2a1f84SWilliam S. Moses     // Check to see that the input is converted from an integer type that is
15111b2a1f84SWilliam S. Moses     // small enough that preserves all bits.
15121b2a1f84SWilliam S. Moses     auto intTy = intVal.getType().cast<IntegerType>();
15131b2a1f84SWilliam S. Moses     auto intWidth = intTy.getWidth();
15141b2a1f84SWilliam S. Moses 
15151b2a1f84SWilliam S. Moses     // Number of bits representing values, as opposed to the sign
15161b2a1f84SWilliam S. Moses     auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
15171b2a1f84SWilliam S. Moses 
15181b2a1f84SWilliam S. Moses     // Following test does NOT adjust intWidth downwards for signed inputs,
15191b2a1f84SWilliam S. Moses     // because the most negative value still requires all the mantissa bits
15201b2a1f84SWilliam S. Moses     // to distinguish it from one less than that value.
15211b2a1f84SWilliam S. Moses     if ((int)intWidth > mantissaWidth) {
15221b2a1f84SWilliam S. Moses       // Conversion would lose accuracy. Check if loss can impact comparison.
15231b2a1f84SWilliam S. Moses       int exponent = ilogb(rhs);
15241b2a1f84SWilliam S. Moses       if (exponent == APFloat::IEK_Inf) {
15251b2a1f84SWilliam S. Moses         int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
15261b2a1f84SWilliam S. Moses         if (maxExponent < (int)valueBits) {
15271b2a1f84SWilliam S. Moses           // Conversion could create infinity.
15281b2a1f84SWilliam S. Moses           return failure();
15291b2a1f84SWilliam S. Moses         }
15301b2a1f84SWilliam S. Moses       } else {
15311b2a1f84SWilliam S. Moses         // Note that if rhs is zero or NaN, then Exp is negative
15321b2a1f84SWilliam S. Moses         // and first condition is trivially false.
15331b2a1f84SWilliam S. Moses         if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
15341b2a1f84SWilliam S. Moses           // Conversion could affect comparison.
15351b2a1f84SWilliam S. Moses           return failure();
15361b2a1f84SWilliam S. Moses         }
15371b2a1f84SWilliam S. Moses       }
15381b2a1f84SWilliam S. Moses     }
15391b2a1f84SWilliam S. Moses 
15401b2a1f84SWilliam S. Moses     // Convert to equivalent cmpi predicate
15411b2a1f84SWilliam S. Moses     CmpIPredicate pred;
15421b2a1f84SWilliam S. Moses     switch (op.getPredicate()) {
15431b2a1f84SWilliam S. Moses     case CmpFPredicate::ORD:
15441b2a1f84SWilliam S. Moses       // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
15451b2a1f84SWilliam S. Moses       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
15461b2a1f84SWilliam S. Moses                                                  /*width=*/1);
15471b2a1f84SWilliam S. Moses       return success();
15481b2a1f84SWilliam S. Moses     case CmpFPredicate::UNO:
15491b2a1f84SWilliam S. Moses       // Int to fp conversion doesn't create a nan (uno checks either is a nan)
15501b2a1f84SWilliam S. Moses       rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
15511b2a1f84SWilliam S. Moses                                                  /*width=*/1);
15521b2a1f84SWilliam S. Moses       return success();
15531b2a1f84SWilliam S. Moses     default:
15541b2a1f84SWilliam S. Moses       pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
15551b2a1f84SWilliam S. Moses       break;
15561b2a1f84SWilliam S. Moses     }
15571b2a1f84SWilliam S. Moses 
15581b2a1f84SWilliam S. Moses     if (!isUnsigned) {
15591b2a1f84SWilliam S. Moses       // If the rhs value is > SignedMax, fold the comparison.  This handles
15601b2a1f84SWilliam S. Moses       // +INF and large values.
15611b2a1f84SWilliam S. Moses       APFloat signedMax(rhs.getSemantics());
15621b2a1f84SWilliam S. Moses       signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
15631b2a1f84SWilliam S. Moses                                  APFloat::rmNearestTiesToEven);
15641b2a1f84SWilliam S. Moses       if (signedMax < rhs) { // smax < 13123.0
15651b2a1f84SWilliam S. Moses         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
15661b2a1f84SWilliam S. Moses             pred == CmpIPredicate::sle)
15671b2a1f84SWilliam S. Moses           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
15681b2a1f84SWilliam S. Moses                                                      /*width=*/1);
15691b2a1f84SWilliam S. Moses         else
15701b2a1f84SWilliam S. Moses           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
15711b2a1f84SWilliam S. Moses                                                      /*width=*/1);
15721b2a1f84SWilliam S. Moses         return success();
15731b2a1f84SWilliam S. Moses       }
15741b2a1f84SWilliam S. Moses     } else {
15751b2a1f84SWilliam S. Moses       // If the rhs value is > UnsignedMax, fold the comparison. This handles
15761b2a1f84SWilliam S. Moses       // +INF and large values.
15771b2a1f84SWilliam S. Moses       APFloat unsignedMax(rhs.getSemantics());
15781b2a1f84SWilliam S. Moses       unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
15791b2a1f84SWilliam S. Moses                                    APFloat::rmNearestTiesToEven);
15801b2a1f84SWilliam S. Moses       if (unsignedMax < rhs) { // umax < 13123.0
15811b2a1f84SWilliam S. Moses         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
15821b2a1f84SWilliam S. Moses             pred == CmpIPredicate::ule)
15831b2a1f84SWilliam S. Moses           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
15841b2a1f84SWilliam S. Moses                                                      /*width=*/1);
15851b2a1f84SWilliam S. Moses         else
15861b2a1f84SWilliam S. Moses           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
15871b2a1f84SWilliam S. Moses                                                      /*width=*/1);
15881b2a1f84SWilliam S. Moses         return success();
15891b2a1f84SWilliam S. Moses       }
15901b2a1f84SWilliam S. Moses     }
15911b2a1f84SWilliam S. Moses 
15921b2a1f84SWilliam S. Moses     if (!isUnsigned) {
15931b2a1f84SWilliam S. Moses       // See if the rhs value is < SignedMin.
15941b2a1f84SWilliam S. Moses       APFloat signedMin(rhs.getSemantics());
15951b2a1f84SWilliam S. Moses       signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
15961b2a1f84SWilliam S. Moses                                  APFloat::rmNearestTiesToEven);
15971b2a1f84SWilliam S. Moses       if (signedMin > rhs) { // smin > 12312.0
15981b2a1f84SWilliam S. Moses         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
15991b2a1f84SWilliam S. Moses             pred == CmpIPredicate::sge)
16001b2a1f84SWilliam S. Moses           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
16011b2a1f84SWilliam S. Moses                                                      /*width=*/1);
16021b2a1f84SWilliam S. Moses         else
16031b2a1f84SWilliam S. Moses           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
16041b2a1f84SWilliam S. Moses                                                      /*width=*/1);
16051b2a1f84SWilliam S. Moses         return success();
16061b2a1f84SWilliam S. Moses       }
16071b2a1f84SWilliam S. Moses     } else {
16081b2a1f84SWilliam S. Moses       // See if the rhs value is < UnsignedMin.
16091b2a1f84SWilliam S. Moses       APFloat unsignedMin(rhs.getSemantics());
16101b2a1f84SWilliam S. Moses       unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
16111b2a1f84SWilliam S. Moses                                    APFloat::rmNearestTiesToEven);
16121b2a1f84SWilliam S. Moses       if (unsignedMin > rhs) { // umin > 12312.0
16131b2a1f84SWilliam S. Moses         if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
16141b2a1f84SWilliam S. Moses             pred == CmpIPredicate::uge)
16151b2a1f84SWilliam S. Moses           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
16161b2a1f84SWilliam S. Moses                                                      /*width=*/1);
16171b2a1f84SWilliam S. Moses         else
16181b2a1f84SWilliam S. Moses           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
16191b2a1f84SWilliam S. Moses                                                      /*width=*/1);
16201b2a1f84SWilliam S. Moses         return success();
16211b2a1f84SWilliam S. Moses       }
16221b2a1f84SWilliam S. Moses     }
16231b2a1f84SWilliam S. Moses 
16241b2a1f84SWilliam S. Moses     // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
16251b2a1f84SWilliam S. Moses     // [0, UMAX], but it may still be fractional.  See if it is fractional by
16261b2a1f84SWilliam S. Moses     // casting the FP value to the integer value and back, checking for
16271b2a1f84SWilliam S. Moses     // equality. Don't do this for zero, because -0.0 is not fractional.
16281b2a1f84SWilliam S. Moses     bool ignored;
16291b2a1f84SWilliam S. Moses     APSInt rhsInt(intWidth, isUnsigned);
16301b2a1f84SWilliam S. Moses     if (APFloat::opInvalidOp ==
16311b2a1f84SWilliam S. Moses         rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
16321b2a1f84SWilliam S. Moses       // Undefined behavior invoked - the destination type can't represent
16331b2a1f84SWilliam S. Moses       // the input constant.
16341b2a1f84SWilliam S. Moses       return failure();
16351b2a1f84SWilliam S. Moses     }
16361b2a1f84SWilliam S. Moses 
16371b2a1f84SWilliam S. Moses     if (!rhs.isZero()) {
16381b2a1f84SWilliam S. Moses       APFloat apf(floatTy.getFloatSemantics(),
16391b2a1f84SWilliam S. Moses                   APInt::getZero(floatTy.getWidth()));
16401b2a1f84SWilliam S. Moses       apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
16411b2a1f84SWilliam S. Moses 
16421b2a1f84SWilliam S. Moses       bool equal = apf == rhs;
16431b2a1f84SWilliam S. Moses       if (!equal) {
16441b2a1f84SWilliam S. Moses         // If we had a comparison against a fractional value, we have to adjust
16451b2a1f84SWilliam S. Moses         // the compare predicate and sometimes the value.  rhsInt is rounded
16461b2a1f84SWilliam S. Moses         // towards zero at this point.
16471b2a1f84SWilliam S. Moses         switch (pred) {
16481b2a1f84SWilliam S. Moses         case CmpIPredicate::ne: // (float)int != 4.4   --> true
16491b2a1f84SWilliam S. Moses           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
16501b2a1f84SWilliam S. Moses                                                      /*width=*/1);
16511b2a1f84SWilliam S. Moses           return success();
16521b2a1f84SWilliam S. Moses         case CmpIPredicate::eq: // (float)int == 4.4   --> false
16531b2a1f84SWilliam S. Moses           rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
16541b2a1f84SWilliam S. Moses                                                      /*width=*/1);
16551b2a1f84SWilliam S. Moses           return success();
16561b2a1f84SWilliam S. Moses         case CmpIPredicate::ule:
16571b2a1f84SWilliam S. Moses           // (float)int <= 4.4   --> int <= 4
16581b2a1f84SWilliam S. Moses           // (float)int <= -4.4  --> false
16591b2a1f84SWilliam S. Moses           if (rhs.isNegative()) {
16601b2a1f84SWilliam S. Moses             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
16611b2a1f84SWilliam S. Moses                                                        /*width=*/1);
16621b2a1f84SWilliam S. Moses             return success();
16631b2a1f84SWilliam S. Moses           }
16641b2a1f84SWilliam S. Moses           break;
16651b2a1f84SWilliam S. Moses         case CmpIPredicate::sle:
16661b2a1f84SWilliam S. Moses           // (float)int <= 4.4   --> int <= 4
16671b2a1f84SWilliam S. Moses           // (float)int <= -4.4  --> int < -4
16681b2a1f84SWilliam S. Moses           if (rhs.isNegative())
16691b2a1f84SWilliam S. Moses             pred = CmpIPredicate::slt;
16701b2a1f84SWilliam S. Moses           break;
16711b2a1f84SWilliam S. Moses         case CmpIPredicate::ult:
16721b2a1f84SWilliam S. Moses           // (float)int < -4.4   --> false
16731b2a1f84SWilliam S. Moses           // (float)int < 4.4    --> int <= 4
16741b2a1f84SWilliam S. Moses           if (rhs.isNegative()) {
16751b2a1f84SWilliam S. Moses             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
16761b2a1f84SWilliam S. Moses                                                        /*width=*/1);
16771b2a1f84SWilliam S. Moses             return success();
16781b2a1f84SWilliam S. Moses           }
16791b2a1f84SWilliam S. Moses           pred = CmpIPredicate::ule;
16801b2a1f84SWilliam S. Moses           break;
16811b2a1f84SWilliam S. Moses         case CmpIPredicate::slt:
16821b2a1f84SWilliam S. Moses           // (float)int < -4.4   --> int < -4
16831b2a1f84SWilliam S. Moses           // (float)int < 4.4    --> int <= 4
16841b2a1f84SWilliam S. Moses           if (!rhs.isNegative())
16851b2a1f84SWilliam S. Moses             pred = CmpIPredicate::sle;
16861b2a1f84SWilliam S. Moses           break;
16871b2a1f84SWilliam S. Moses         case CmpIPredicate::ugt:
16881b2a1f84SWilliam S. Moses           // (float)int > 4.4    --> int > 4
16891b2a1f84SWilliam S. Moses           // (float)int > -4.4   --> true
16901b2a1f84SWilliam S. Moses           if (rhs.isNegative()) {
16911b2a1f84SWilliam S. Moses             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
16921b2a1f84SWilliam S. Moses                                                        /*width=*/1);
16931b2a1f84SWilliam S. Moses             return success();
16941b2a1f84SWilliam S. Moses           }
16951b2a1f84SWilliam S. Moses           break;
16961b2a1f84SWilliam S. Moses         case CmpIPredicate::sgt:
16971b2a1f84SWilliam S. Moses           // (float)int > 4.4    --> int > 4
16981b2a1f84SWilliam S. Moses           // (float)int > -4.4   --> int >= -4
16991b2a1f84SWilliam S. Moses           if (rhs.isNegative())
17001b2a1f84SWilliam S. Moses             pred = CmpIPredicate::sge;
17011b2a1f84SWilliam S. Moses           break;
17021b2a1f84SWilliam S. Moses         case CmpIPredicate::uge:
17031b2a1f84SWilliam S. Moses           // (float)int >= -4.4   --> true
17041b2a1f84SWilliam S. Moses           // (float)int >= 4.4    --> int > 4
17051b2a1f84SWilliam S. Moses           if (rhs.isNegative()) {
17061b2a1f84SWilliam S. Moses             rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
17071b2a1f84SWilliam S. Moses                                                        /*width=*/1);
17081b2a1f84SWilliam S. Moses             return success();
17091b2a1f84SWilliam S. Moses           }
17101b2a1f84SWilliam S. Moses           pred = CmpIPredicate::ugt;
17111b2a1f84SWilliam S. Moses           break;
17121b2a1f84SWilliam S. Moses         case CmpIPredicate::sge:
17131b2a1f84SWilliam S. Moses           // (float)int >= -4.4   --> int >= -4
17141b2a1f84SWilliam S. Moses           // (float)int >= 4.4    --> int > 4
17151b2a1f84SWilliam S. Moses           if (!rhs.isNegative())
17161b2a1f84SWilliam S. Moses             pred = CmpIPredicate::sgt;
17171b2a1f84SWilliam S. Moses           break;
17181b2a1f84SWilliam S. Moses         }
17191b2a1f84SWilliam S. Moses       }
17201b2a1f84SWilliam S. Moses     }
17211b2a1f84SWilliam S. Moses 
17221b2a1f84SWilliam S. Moses     // Lower this FP comparison into an appropriate integer version of the
17231b2a1f84SWilliam S. Moses     // comparison.
17241b2a1f84SWilliam S. Moses     rewriter.replaceOpWithNewOp<CmpIOp>(
17251b2a1f84SWilliam S. Moses         op, pred, intVal,
17261b2a1f84SWilliam S. Moses         rewriter.create<ConstantOp>(
17271b2a1f84SWilliam S. Moses             op.getLoc(), intVal.getType(),
17281b2a1f84SWilliam S. Moses             rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
17291b2a1f84SWilliam S. Moses     return success();
17301b2a1f84SWilliam S. Moses   }
17311b2a1f84SWilliam S. Moses };
17321b2a1f84SWilliam S. Moses 
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)17331b2a1f84SWilliam S. Moses void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
17341b2a1f84SWilliam S. Moses                                                 MLIRContext *context) {
17351b2a1f84SWilliam S. Moses   patterns.insert<CmpFIntToFPConst>(context);
17361b2a1f84SWilliam S. Moses }
17371b2a1f84SWilliam S. Moses 
17388c08f21bSMogball //===----------------------------------------------------------------------===//
1739dec8af70SRiver Riddle // SelectOp
1740dec8af70SRiver Riddle //===----------------------------------------------------------------------===//
1741dec8af70SRiver Riddle 
1742dec8af70SRiver Riddle // Transforms a select of a boolean to arithmetic operations
1743dec8af70SRiver Riddle //
1744dec8af70SRiver Riddle //  arith.select %arg, %x, %y : i1
1745dec8af70SRiver Riddle //
1746dec8af70SRiver Riddle //  becomes
1747dec8af70SRiver Riddle //
1748dec8af70SRiver Riddle //  and(%arg, %x) or and(!%arg, %y)
1749dec8af70SRiver Riddle struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1750dec8af70SRiver Riddle   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1751dec8af70SRiver Riddle 
matchAndRewriteSelectI1Simplify1752dec8af70SRiver Riddle   LogicalResult matchAndRewrite(arith::SelectOp op,
1753dec8af70SRiver Riddle                                 PatternRewriter &rewriter) const override {
1754dec8af70SRiver Riddle     if (!op.getType().isInteger(1))
1755dec8af70SRiver Riddle       return failure();
1756dec8af70SRiver Riddle 
1757dec8af70SRiver Riddle     Value falseConstant =
1758dec8af70SRiver Riddle         rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1759dec8af70SRiver Riddle     Value notCondition = rewriter.create<arith::XOrIOp>(
1760dec8af70SRiver Riddle         op.getLoc(), op.getCondition(), falseConstant);
1761dec8af70SRiver Riddle 
1762dec8af70SRiver Riddle     Value trueVal = rewriter.create<arith::AndIOp>(
1763dec8af70SRiver Riddle         op.getLoc(), op.getCondition(), op.getTrueValue());
1764dec8af70SRiver Riddle     Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1765dec8af70SRiver Riddle                                                     op.getFalseValue());
1766dec8af70SRiver Riddle     rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1767dec8af70SRiver Riddle     return success();
1768dec8af70SRiver Riddle   }
1769dec8af70SRiver Riddle };
1770dec8af70SRiver Riddle 
1771dec8af70SRiver Riddle //  select %arg, %c1, %c0 => extui %arg
1772dec8af70SRiver Riddle struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1773dec8af70SRiver Riddle   using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1774dec8af70SRiver Riddle 
matchAndRewriteSelectToExtUI1775dec8af70SRiver Riddle   LogicalResult matchAndRewrite(arith::SelectOp op,
1776dec8af70SRiver Riddle                                 PatternRewriter &rewriter) const override {
1777dec8af70SRiver Riddle     // Cannot extui i1 to i1, or i1 to f32
1778dec8af70SRiver Riddle     if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1779dec8af70SRiver Riddle       return failure();
1780dec8af70SRiver Riddle 
1781dec8af70SRiver Riddle     // select %x, c1, %c0 => extui %arg
178210c9ecceSjacquesguan     if (matchPattern(op.getTrueValue(), m_One()) &&
178310c9ecceSjacquesguan         matchPattern(op.getFalseValue(), m_Zero())) {
1784dec8af70SRiver Riddle       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1785dec8af70SRiver Riddle                                                   op.getCondition());
1786dec8af70SRiver Riddle       return success();
1787dec8af70SRiver Riddle     }
1788dec8af70SRiver Riddle 
1789dec8af70SRiver Riddle     // select %x, c0, %c1 => extui (xor %arg, true)
179010c9ecceSjacquesguan     if (matchPattern(op.getTrueValue(), m_Zero()) &&
179110c9ecceSjacquesguan         matchPattern(op.getFalseValue(), m_One())) {
1792dec8af70SRiver Riddle       rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1793dec8af70SRiver Riddle           op, op.getType(),
1794dec8af70SRiver Riddle           rewriter.create<arith::XOrIOp>(
1795dec8af70SRiver Riddle               op.getLoc(), op.getCondition(),
1796dec8af70SRiver Riddle               rewriter.create<arith::ConstantIntOp>(
1797dec8af70SRiver Riddle                   op.getLoc(), 1, op.getCondition().getType())));
1798dec8af70SRiver Riddle       return success();
1799dec8af70SRiver Riddle     }
1800dec8af70SRiver Riddle 
1801dec8af70SRiver Riddle     return failure();
1802dec8af70SRiver Riddle   }
1803dec8af70SRiver Riddle };
1804dec8af70SRiver Riddle 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1805dec8af70SRiver Riddle void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1806dec8af70SRiver Riddle                                                   MLIRContext *context) {
1807b4e0507cSTres Popp   results.add<SelectI1Simplify, SelectToExtUI>(context);
1808dec8af70SRiver Riddle }
1809dec8af70SRiver Riddle 
fold(ArrayRef<Attribute> operands)1810dec8af70SRiver Riddle OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1811dec8af70SRiver Riddle   Value trueVal = getTrueValue();
1812dec8af70SRiver Riddle   Value falseVal = getFalseValue();
1813dec8af70SRiver Riddle   if (trueVal == falseVal)
1814dec8af70SRiver Riddle     return trueVal;
1815dec8af70SRiver Riddle 
1816dec8af70SRiver Riddle   Value condition = getCondition();
1817dec8af70SRiver Riddle 
1818dec8af70SRiver Riddle   // select true, %0, %1 => %0
1819dec8af70SRiver Riddle   if (matchPattern(condition, m_One()))
1820dec8af70SRiver Riddle     return trueVal;
1821dec8af70SRiver Riddle 
1822dec8af70SRiver Riddle   // select false, %0, %1 => %1
1823dec8af70SRiver Riddle   if (matchPattern(condition, m_Zero()))
1824dec8af70SRiver Riddle     return falseVal;
1825dec8af70SRiver Riddle 
1826dec8af70SRiver Riddle   // select %x, true, false => %x
182710c9ecceSjacquesguan   if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
182810c9ecceSjacquesguan       matchPattern(getFalseValue(), m_Zero()))
1829dec8af70SRiver Riddle     return condition;
1830dec8af70SRiver Riddle 
1831dec8af70SRiver Riddle   if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1832dec8af70SRiver Riddle     auto pred = cmp.getPredicate();
1833dec8af70SRiver Riddle     if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1834dec8af70SRiver Riddle       auto cmpLhs = cmp.getLhs();
1835dec8af70SRiver Riddle       auto cmpRhs = cmp.getRhs();
1836dec8af70SRiver Riddle 
1837dec8af70SRiver Riddle       // %0 = arith.cmpi eq, %arg0, %arg1
1838dec8af70SRiver Riddle       // %1 = arith.select %0, %arg0, %arg1 => %arg1
1839dec8af70SRiver Riddle 
1840dec8af70SRiver Riddle       // %0 = arith.cmpi ne, %arg0, %arg1
1841dec8af70SRiver Riddle       // %1 = arith.select %0, %arg0, %arg1 => %arg0
1842dec8af70SRiver Riddle 
1843dec8af70SRiver Riddle       if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1844dec8af70SRiver Riddle           (cmpRhs == trueVal && cmpLhs == falseVal))
1845dec8af70SRiver Riddle         return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1846dec8af70SRiver Riddle     }
1847dec8af70SRiver Riddle   }
1848dec8af70SRiver Riddle   return nullptr;
1849dec8af70SRiver Riddle }
1850dec8af70SRiver Riddle 
parse(OpAsmParser & parser,OperationState & result)18512418cd92SRiver Riddle ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1852dec8af70SRiver Riddle   Type conditionType, resultType;
1853e13d23bcSMarkus Böck   SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1854dec8af70SRiver Riddle   if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1855dec8af70SRiver Riddle       parser.parseOptionalAttrDict(result.attributes) ||
1856dec8af70SRiver Riddle       parser.parseColonType(resultType))
1857dec8af70SRiver Riddle     return failure();
1858dec8af70SRiver Riddle 
1859dec8af70SRiver Riddle   // Check for the explicit condition type if this is a masked tensor or vector.
1860dec8af70SRiver Riddle   if (succeeded(parser.parseOptionalComma())) {
1861dec8af70SRiver Riddle     conditionType = resultType;
1862dec8af70SRiver Riddle     if (parser.parseType(resultType))
1863dec8af70SRiver Riddle       return failure();
1864dec8af70SRiver Riddle   } else {
1865dec8af70SRiver Riddle     conditionType = parser.getBuilder().getI1Type();
1866dec8af70SRiver Riddle   }
1867dec8af70SRiver Riddle 
1868dec8af70SRiver Riddle   result.addTypes(resultType);
1869dec8af70SRiver Riddle   return parser.resolveOperands(operands,
1870dec8af70SRiver Riddle                                 {conditionType, resultType, resultType},
1871dec8af70SRiver Riddle                                 parser.getNameLoc(), result.operands);
1872dec8af70SRiver Riddle }
1873dec8af70SRiver Riddle 
print(OpAsmPrinter & p)18742418cd92SRiver Riddle void arith::SelectOp::print(OpAsmPrinter &p) {
18752418cd92SRiver Riddle   p << " " << getOperands();
18762418cd92SRiver Riddle   p.printOptionalAttrDict((*this)->getAttrs());
18772418cd92SRiver Riddle   p << " : ";
18782418cd92SRiver Riddle   if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
18792418cd92SRiver Riddle     p << condType << ", ";
18802418cd92SRiver Riddle   p << getType();
18812418cd92SRiver Riddle }
18822418cd92SRiver Riddle 
verify()1883dec8af70SRiver Riddle LogicalResult arith::SelectOp::verify() {
1884dec8af70SRiver Riddle   Type conditionType = getCondition().getType();
1885dec8af70SRiver Riddle   if (conditionType.isSignlessInteger(1))
1886dec8af70SRiver Riddle     return success();
1887dec8af70SRiver Riddle 
1888dec8af70SRiver Riddle   // If the result type is a vector or tensor, the type can be a mask with the
1889dec8af70SRiver Riddle   // same elements.
1890dec8af70SRiver Riddle   Type resultType = getType();
1891dec8af70SRiver Riddle   if (!resultType.isa<TensorType, VectorType>())
1892dec8af70SRiver Riddle     return emitOpError() << "expected condition to be a signless i1, but got "
1893dec8af70SRiver Riddle                          << conditionType;
1894dec8af70SRiver Riddle   Type shapedConditionType = getI1SameShape(resultType);
1895dec8af70SRiver Riddle   if (conditionType != shapedConditionType) {
1896dec8af70SRiver Riddle     return emitOpError() << "expected condition type to have the same shape "
1897dec8af70SRiver Riddle                             "as the result type, expected "
1898dec8af70SRiver Riddle                          << shapedConditionType << ", but got "
1899dec8af70SRiver Riddle                          << conditionType;
1900dec8af70SRiver Riddle   }
1901dec8af70SRiver Riddle   return success();
1902dec8af70SRiver Riddle }
1903db31da27SWilliam S. Moses //===----------------------------------------------------------------------===//
1904db31da27SWilliam S. Moses // ShLIOp
1905db31da27SWilliam S. Moses //===----------------------------------------------------------------------===//
1906db31da27SWilliam S. Moses 
fold(ArrayRef<Attribute> operands)1907db31da27SWilliam S. Moses OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1908db31da27SWilliam S. Moses   // Don't fold if shifting more than the bit width.
1909db31da27SWilliam S. Moses   bool bounded = false;
191051894cbbSMehdi Amini   auto result = constFoldBinaryOp<IntegerAttr>(
191151894cbbSMehdi Amini       operands, [&](const APInt &a, const APInt &b) {
1912db31da27SWilliam S. Moses         bounded = b.ule(b.getBitWidth());
1913bf62a4b9SMehdi Amini         return a.shl(b);
1914db31da27SWilliam S. Moses       });
1915db31da27SWilliam S. Moses   return bounded ? result : Attribute();
1916db31da27SWilliam S. Moses }
1917dec8af70SRiver Riddle 
1918dec8af70SRiver Riddle //===----------------------------------------------------------------------===//
191955053205Sjacquesguan // ShRUIOp
192055053205Sjacquesguan //===----------------------------------------------------------------------===//
192155053205Sjacquesguan 
fold(ArrayRef<Attribute> operands)192255053205Sjacquesguan OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
192355053205Sjacquesguan   // Don't fold if shifting more than the bit width.
192455053205Sjacquesguan   bool bounded = false;
192555053205Sjacquesguan   auto result = constFoldBinaryOp<IntegerAttr>(
192655053205Sjacquesguan       operands, [&](const APInt &a, const APInt &b) {
192755053205Sjacquesguan         bounded = b.ule(b.getBitWidth());
1928bf62a4b9SMehdi Amini         return a.lshr(b);
192955053205Sjacquesguan       });
193055053205Sjacquesguan   return bounded ? result : Attribute();
193155053205Sjacquesguan }
193255053205Sjacquesguan 
193355053205Sjacquesguan //===----------------------------------------------------------------------===//
193455053205Sjacquesguan // ShRSIOp
193555053205Sjacquesguan //===----------------------------------------------------------------------===//
193655053205Sjacquesguan 
fold(ArrayRef<Attribute> operands)193755053205Sjacquesguan OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
193855053205Sjacquesguan   // Don't fold if shifting more than the bit width.
193955053205Sjacquesguan   bool bounded = false;
194055053205Sjacquesguan   auto result = constFoldBinaryOp<IntegerAttr>(
194155053205Sjacquesguan       operands, [&](const APInt &a, const APInt &b) {
194255053205Sjacquesguan         bounded = b.ule(b.getBitWidth());
1943bf62a4b9SMehdi Amini         return a.ashr(b);
194455053205Sjacquesguan       });
194555053205Sjacquesguan   return bounded ? result : Attribute();
194655053205Sjacquesguan }
194755053205Sjacquesguan 
194855053205Sjacquesguan //===----------------------------------------------------------------------===//
1949a6a583daSWilliam S. Moses // Atomic Enum
1950a6a583daSWilliam S. Moses //===----------------------------------------------------------------------===//
1951a6a583daSWilliam S. Moses 
1952a6a583daSWilliam S. Moses /// Returns the identity value attribute associated with an AtomicRMWKind op.
getIdentityValueAttr(AtomicRMWKind kind,Type resultType,OpBuilder & builder,Location loc)1953a6a583daSWilliam S. Moses Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1954a6a583daSWilliam S. Moses                                             OpBuilder &builder, Location loc) {
1955a6a583daSWilliam S. Moses   switch (kind) {
1956a6a583daSWilliam S. Moses   case AtomicRMWKind::maxf:
1957a6a583daSWilliam S. Moses     return builder.getFloatAttr(
1958a6a583daSWilliam S. Moses         resultType,
1959a6a583daSWilliam S. Moses         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1960a6a583daSWilliam S. Moses                         /*Negative=*/true));
1961a6a583daSWilliam S. Moses   case AtomicRMWKind::addf:
1962a6a583daSWilliam S. Moses   case AtomicRMWKind::addi:
1963a6a583daSWilliam S. Moses   case AtomicRMWKind::maxu:
1964a6a583daSWilliam S. Moses   case AtomicRMWKind::ori:
1965a6a583daSWilliam S. Moses     return builder.getZeroAttr(resultType);
1966a6a583daSWilliam S. Moses   case AtomicRMWKind::andi:
1967a6a583daSWilliam S. Moses     return builder.getIntegerAttr(
1968a6a583daSWilliam S. Moses         resultType,
1969a6a583daSWilliam S. Moses         APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1970a6a583daSWilliam S. Moses   case AtomicRMWKind::maxs:
1971a6a583daSWilliam S. Moses     return builder.getIntegerAttr(
1972a6a583daSWilliam S. Moses         resultType,
1973a6a583daSWilliam S. Moses         APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1974a6a583daSWilliam S. Moses   case AtomicRMWKind::minf:
1975a6a583daSWilliam S. Moses     return builder.getFloatAttr(
1976a6a583daSWilliam S. Moses         resultType,
1977a6a583daSWilliam S. Moses         APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1978a6a583daSWilliam S. Moses                         /*Negative=*/false));
1979a6a583daSWilliam S. Moses   case AtomicRMWKind::mins:
1980a6a583daSWilliam S. Moses     return builder.getIntegerAttr(
1981a6a583daSWilliam S. Moses         resultType,
1982a6a583daSWilliam S. Moses         APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1983a6a583daSWilliam S. Moses   case AtomicRMWKind::minu:
1984a6a583daSWilliam S. Moses     return builder.getIntegerAttr(
1985a6a583daSWilliam S. Moses         resultType,
1986a6a583daSWilliam S. Moses         APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1987a6a583daSWilliam S. Moses   case AtomicRMWKind::muli:
1988a6a583daSWilliam S. Moses     return builder.getIntegerAttr(resultType, 1);
1989a6a583daSWilliam S. Moses   case AtomicRMWKind::mulf:
1990a6a583daSWilliam S. Moses     return builder.getFloatAttr(resultType, 1);
1991a6a583daSWilliam S. Moses   // TODO: Add remaining reduction operations.
1992a6a583daSWilliam S. Moses   default:
1993a6a583daSWilliam S. Moses     (void)emitOptionalError(loc, "Reduction operation type not supported");
1994a6a583daSWilliam S. Moses     break;
1995a6a583daSWilliam S. Moses   }
1996a6a583daSWilliam S. Moses   return nullptr;
1997a6a583daSWilliam S. Moses }
1998a6a583daSWilliam S. Moses 
1999a6a583daSWilliam S. Moses /// Returns the identity value associated with an AtomicRMWKind op.
getIdentityValue(AtomicRMWKind op,Type resultType,OpBuilder & builder,Location loc)2000a6a583daSWilliam S. Moses Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2001a6a583daSWilliam S. Moses                                     OpBuilder &builder, Location loc) {
2002a6a583daSWilliam S. Moses   Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
2003a6a583daSWilliam S. Moses   return builder.create<arith::ConstantOp>(loc, attr);
2004a6a583daSWilliam S. Moses }
2005a6a583daSWilliam S. Moses 
2006a6a583daSWilliam S. Moses /// Return the value obtained by applying the reduction operation kind
2007a6a583daSWilliam S. Moses /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
getReductionOp(AtomicRMWKind op,OpBuilder & builder,Location loc,Value lhs,Value rhs)2008a6a583daSWilliam S. Moses Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2009a6a583daSWilliam S. Moses                                   Location loc, Value lhs, Value rhs) {
2010a6a583daSWilliam S. Moses   switch (op) {
2011a6a583daSWilliam S. Moses   case AtomicRMWKind::addf:
2012a6a583daSWilliam S. Moses     return builder.create<arith::AddFOp>(loc, lhs, rhs);
2013a6a583daSWilliam S. Moses   case AtomicRMWKind::addi:
2014a6a583daSWilliam S. Moses     return builder.create<arith::AddIOp>(loc, lhs, rhs);
2015a6a583daSWilliam S. Moses   case AtomicRMWKind::mulf:
2016a6a583daSWilliam S. Moses     return builder.create<arith::MulFOp>(loc, lhs, rhs);
2017a6a583daSWilliam S. Moses   case AtomicRMWKind::muli:
2018a6a583daSWilliam S. Moses     return builder.create<arith::MulIOp>(loc, lhs, rhs);
2019a6a583daSWilliam S. Moses   case AtomicRMWKind::maxf:
2020a6a583daSWilliam S. Moses     return builder.create<arith::MaxFOp>(loc, lhs, rhs);
2021a6a583daSWilliam S. Moses   case AtomicRMWKind::minf:
2022a6a583daSWilliam S. Moses     return builder.create<arith::MinFOp>(loc, lhs, rhs);
2023a6a583daSWilliam S. Moses   case AtomicRMWKind::maxs:
2024a6a583daSWilliam S. Moses     return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2025a6a583daSWilliam S. Moses   case AtomicRMWKind::mins:
2026a6a583daSWilliam S. Moses     return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2027a6a583daSWilliam S. Moses   case AtomicRMWKind::maxu:
2028a6a583daSWilliam S. Moses     return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2029a6a583daSWilliam S. Moses   case AtomicRMWKind::minu:
2030a6a583daSWilliam S. Moses     return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2031a6a583daSWilliam S. Moses   case AtomicRMWKind::ori:
2032a6a583daSWilliam S. Moses     return builder.create<arith::OrIOp>(loc, lhs, rhs);
2033a6a583daSWilliam S. Moses   case AtomicRMWKind::andi:
2034a6a583daSWilliam S. Moses     return builder.create<arith::AndIOp>(loc, lhs, rhs);
2035a6a583daSWilliam S. Moses   // TODO: Add remaining reduction operations.
2036a6a583daSWilliam S. Moses   default:
2037a6a583daSWilliam S. Moses     (void)emitOptionalError(loc, "Reduction operation type not supported");
2038a6a583daSWilliam S. Moses     break;
2039a6a583daSWilliam S. Moses   }
2040a6a583daSWilliam S. Moses   return nullptr;
2041a6a583daSWilliam S. Moses }
2042a6a583daSWilliam S. Moses 
2043a6a583daSWilliam S. Moses //===----------------------------------------------------------------------===//
20448c08f21bSMogball // TableGen'd op method definitions
20458c08f21bSMogball //===----------------------------------------------------------------------===//
20468c08f21bSMogball 
20478c08f21bSMogball #define GET_OP_CLASSES
20488c08f21bSMogball #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
20498c08f21bSMogball 
20508c08f21bSMogball //===----------------------------------------------------------------------===//
20518c08f21bSMogball // TableGen'd enum attribute definitions
20528c08f21bSMogball //===----------------------------------------------------------------------===//
20538c08f21bSMogball 
20548c08f21bSMogball #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2055