1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
10 #include "mlir/Dialect/CommonFolders.h"
11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/OpImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/IR/TypeUtilities.h"
16 
17 using namespace mlir;
18 using namespace mlir::arith;
19 
20 //===----------------------------------------------------------------------===//
21 // Pattern helpers
22 //===----------------------------------------------------------------------===//
23 
24 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
25                                    Attribute lhs, Attribute rhs) {
26   return builder.getIntegerAttr(res.getType(),
27                                 lhs.cast<IntegerAttr>().getInt() +
28                                     rhs.cast<IntegerAttr>().getInt());
29 }
30 
31 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
32                                    Attribute lhs, Attribute rhs) {
33   return builder.getIntegerAttr(res.getType(),
34                                 lhs.cast<IntegerAttr>().getInt() -
35                                     rhs.cast<IntegerAttr>().getInt());
36 }
37 
38 /// Invert an integer comparison predicate.
39 static arith::CmpIPredicate invertPredicate(arith::CmpIPredicate pred) {
40   switch (pred) {
41   case arith::CmpIPredicate::eq:
42     return arith::CmpIPredicate::ne;
43   case arith::CmpIPredicate::ne:
44     return arith::CmpIPredicate::eq;
45   case arith::CmpIPredicate::slt:
46     return arith::CmpIPredicate::sge;
47   case arith::CmpIPredicate::sle:
48     return arith::CmpIPredicate::sgt;
49   case arith::CmpIPredicate::sgt:
50     return arith::CmpIPredicate::sle;
51   case arith::CmpIPredicate::sge:
52     return arith::CmpIPredicate::slt;
53   case arith::CmpIPredicate::ult:
54     return arith::CmpIPredicate::uge;
55   case arith::CmpIPredicate::ule:
56     return arith::CmpIPredicate::ugt;
57   case arith::CmpIPredicate::ugt:
58     return arith::CmpIPredicate::ule;
59   case arith::CmpIPredicate::uge:
60     return arith::CmpIPredicate::ult;
61   }
62   llvm_unreachable("unknown cmpi predicate kind");
63 }
64 
65 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
66   return arith::CmpIPredicateAttr::get(pred.getContext(),
67                                        invertPredicate(pred.getValue()));
68 }
69 
70 //===----------------------------------------------------------------------===//
71 // TableGen'd canonicalization patterns
72 //===----------------------------------------------------------------------===//
73 
74 namespace {
75 #include "ArithmeticCanonicalization.inc"
76 } // end anonymous namespace
77 
78 //===----------------------------------------------------------------------===//
79 // ConstantOp
80 //===----------------------------------------------------------------------===//
81 
82 void arith::ConstantOp::getAsmResultNames(
83     function_ref<void(Value, StringRef)> setNameFn) {
84   auto type = getType();
85   if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
86     auto intType = type.dyn_cast<IntegerType>();
87 
88     // Sugar i1 constants with 'true' and 'false'.
89     if (intType && intType.getWidth() == 1)
90       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
91 
92     // Otherwise, build a compex name with the value and type.
93     SmallString<32> specialNameBuffer;
94     llvm::raw_svector_ostream specialName(specialNameBuffer);
95     specialName << 'c' << intCst.getInt();
96     if (intType)
97       specialName << '_' << type;
98     setNameFn(getResult(), specialName.str());
99   } else {
100     setNameFn(getResult(), "cst");
101   }
102 }
103 
104 /// TODO: disallow arith.constant to return anything other than signless integer
105 /// or float like.
106 static LogicalResult verify(arith::ConstantOp op) {
107   auto type = op.getType();
108   // The value's type must match the return type.
109   if (op.getValue().getType() != type) {
110     return op.emitOpError() << "value type " << op.getValue().getType()
111                             << " must match return type: " << type;
112   }
113   // Integer values must be signless.
114   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
115     return op.emitOpError("integer return type must be signless");
116   // Any float or elements attribute are acceptable.
117   if (!op.getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
118     return op.emitOpError(
119         "value must be an integer, float, or elements attribute");
120   }
121   return success();
122 }
123 
124 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
125   // The value's type must be the same as the provided type.
126   if (value.getType() != type)
127     return false;
128   // Integer values must be signless.
129   if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
130     return false;
131   // Integer, float, and element attributes are buildable.
132   return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
133 }
134 
135 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
136   return getValue();
137 }
138 
139 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
140                                  int64_t value, unsigned width) {
141   auto type = builder.getIntegerType(width);
142   arith::ConstantOp::build(builder, result, type,
143                            builder.getIntegerAttr(type, value));
144 }
145 
146 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
147                                  int64_t value, Type type) {
148   assert(type.isSignlessInteger() &&
149          "ConstantIntOp can only have signless integer type values");
150   arith::ConstantOp::build(builder, result, type,
151                            builder.getIntegerAttr(type, value));
152 }
153 
154 bool arith::ConstantIntOp::classof(Operation *op) {
155   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
156     return constOp.getType().isSignlessInteger();
157   return false;
158 }
159 
160 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
161                                    const APFloat &value, FloatType type) {
162   arith::ConstantOp::build(builder, result, type,
163                            builder.getFloatAttr(type, value));
164 }
165 
166 bool arith::ConstantFloatOp::classof(Operation *op) {
167   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
168     return constOp.getType().isa<FloatType>();
169   return false;
170 }
171 
172 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
173                                    int64_t value) {
174   arith::ConstantOp::build(builder, result, builder.getIndexType(),
175                            builder.getIndexAttr(value));
176 }
177 
178 bool arith::ConstantIndexOp::classof(Operation *op) {
179   if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
180     return constOp.getType().isIndex();
181   return false;
182 }
183 
184 //===----------------------------------------------------------------------===//
185 // AddIOp
186 //===----------------------------------------------------------------------===//
187 
188 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
189   // addi(x, 0) -> x
190   if (matchPattern(getRhs(), m_Zero()))
191     return getLhs();
192 
193   return constFoldBinaryOp<IntegerAttr>(operands,
194                                         [](APInt a, APInt b) { return a + b; });
195 }
196 
197 void arith::AddIOp::getCanonicalizationPatterns(
198     OwningRewritePatternList &patterns, MLIRContext *context) {
199   patterns.insert<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
200       context);
201 }
202 
203 //===----------------------------------------------------------------------===//
204 // SubIOp
205 //===----------------------------------------------------------------------===//
206 
207 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
208   // subi(x,x) -> 0
209   if (getOperand(0) == getOperand(1))
210     return Builder(getContext()).getZeroAttr(getType());
211   // subi(x,0) -> x
212   if (matchPattern(getRhs(), m_Zero()))
213     return getLhs();
214 
215   return constFoldBinaryOp<IntegerAttr>(operands,
216                                         [](APInt a, APInt b) { return a - b; });
217 }
218 
219 void arith::SubIOp::getCanonicalizationPatterns(
220     OwningRewritePatternList &patterns, MLIRContext *context) {
221   patterns.insert<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
222                   SubIRHSSubConstantLHS, SubILHSSubConstantRHS,
223                   SubILHSSubConstantLHS>(context);
224 }
225 
226 //===----------------------------------------------------------------------===//
227 // MulIOp
228 //===----------------------------------------------------------------------===//
229 
230 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
231   // muli(x, 0) -> 0
232   if (matchPattern(getRhs(), m_Zero()))
233     return getRhs();
234   // muli(x, 1) -> x
235   if (matchPattern(getRhs(), m_One()))
236     return getOperand(0);
237   // TODO: Handle the overflow case.
238 
239   // default folder
240   return constFoldBinaryOp<IntegerAttr>(operands,
241                                         [](APInt a, APInt b) { return a * b; });
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // DivUIOp
246 //===----------------------------------------------------------------------===//
247 
248 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
249   // Don't fold if it would require a division by zero.
250   bool div0 = false;
251   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
252     if (div0 || !b) {
253       div0 = true;
254       return a;
255     }
256     return a.udiv(b);
257   });
258 
259   // Fold out division by one. Assumes all tensors of all ones are splats.
260   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
261     if (rhs.getValue() == 1)
262       return getLhs();
263   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
264     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
265       return getLhs();
266   }
267 
268   return div0 ? Attribute() : result;
269 }
270 
271 //===----------------------------------------------------------------------===//
272 // DivSIOp
273 //===----------------------------------------------------------------------===//
274 
275 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
276   // Don't fold if it would overflow or if it requires a division by zero.
277   bool overflowOrDiv0 = false;
278   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
279     if (overflowOrDiv0 || !b) {
280       overflowOrDiv0 = true;
281       return a;
282     }
283     return a.sdiv_ov(b, overflowOrDiv0);
284   });
285 
286   // Fold out division by one. Assumes all tensors of all ones are splats.
287   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
288     if (rhs.getValue() == 1)
289       return getLhs();
290   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
291     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
292       return getLhs();
293   }
294 
295   return overflowOrDiv0 ? Attribute() : result;
296 }
297 
298 //===----------------------------------------------------------------------===//
299 // Ceil and floor division folding helpers
300 //===----------------------------------------------------------------------===//
301 
302 static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) {
303   // Returns (a-1)/b + 1
304   APInt one(a.getBitWidth(), 1, true); // Signed value 1.
305   APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow);
306   return val.sadd_ov(one, overflow);
307 }
308 
309 //===----------------------------------------------------------------------===//
310 // CeilDivSIOp
311 //===----------------------------------------------------------------------===//
312 
313 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
314   // Don't fold if it would overflow or if it requires a division by zero.
315   bool overflowOrDiv0 = false;
316   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
317     if (overflowOrDiv0 || !b) {
318       overflowOrDiv0 = true;
319       return a;
320     }
321     unsigned bits = a.getBitWidth();
322     APInt zero = APInt::getZero(bits);
323     if (a.sgt(zero) && b.sgt(zero)) {
324       // Both positive, return ceil(a, b).
325       return signedCeilNonnegInputs(a, b, overflowOrDiv0);
326     }
327     if (a.slt(zero) && b.slt(zero)) {
328       // Both negative, return ceil(-a, -b).
329       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
330       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
331       return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
332     }
333     if (a.slt(zero) && b.sgt(zero)) {
334       // A is negative, b is positive, return - ( -a / b).
335       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
336       APInt div = posA.sdiv_ov(b, overflowOrDiv0);
337       return zero.ssub_ov(div, overflowOrDiv0);
338     }
339     // A is positive (or zero), b is negative, return - (a / -b).
340     APInt posB = zero.ssub_ov(b, overflowOrDiv0);
341     APInt div = a.sdiv_ov(posB, overflowOrDiv0);
342     return zero.ssub_ov(div, overflowOrDiv0);
343   });
344 
345   // Fold out floor division by one. Assumes all tensors of all ones are
346   // splats.
347   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
348     if (rhs.getValue() == 1)
349       return getLhs();
350   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
351     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
352       return getLhs();
353   }
354 
355   return overflowOrDiv0 ? Attribute() : result;
356 }
357 
358 //===----------------------------------------------------------------------===//
359 // FloorDivSIOp
360 //===----------------------------------------------------------------------===//
361 
362 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
363   // Don't fold if it would overflow or if it requires a division by zero.
364   bool overflowOrDiv0 = false;
365   auto result = constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, APInt b) {
366     if (overflowOrDiv0 || !b) {
367       overflowOrDiv0 = true;
368       return a;
369     }
370     unsigned bits = a.getBitWidth();
371     APInt zero = APInt::getZero(bits);
372     if (a.sge(zero) && b.sgt(zero)) {
373       // Both positive (or a is zero), return a / b.
374       return a.sdiv_ov(b, overflowOrDiv0);
375     }
376     if (a.sle(zero) && b.slt(zero)) {
377       // Both negative (or a is zero), return -a / -b.
378       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
379       APInt posB = zero.ssub_ov(b, overflowOrDiv0);
380       return posA.sdiv_ov(posB, overflowOrDiv0);
381     }
382     if (a.slt(zero) && b.sgt(zero)) {
383       // A is negative, b is positive, return - ceil(-a, b).
384       APInt posA = zero.ssub_ov(a, overflowOrDiv0);
385       APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
386       return zero.ssub_ov(ceil, overflowOrDiv0);
387     }
388     // A is positive, b is negative, return - ceil(a, -b).
389     APInt posB = zero.ssub_ov(b, overflowOrDiv0);
390     APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
391     return zero.ssub_ov(ceil, overflowOrDiv0);
392   });
393 
394   // Fold out floor division by one. Assumes all tensors of all ones are
395   // splats.
396   if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>()) {
397     if (rhs.getValue() == 1)
398       return getLhs();
399   } else if (auto rhs = operands[1].dyn_cast_or_null<SplatElementsAttr>()) {
400     if (rhs.getSplatValue<IntegerAttr>().getValue() == 1)
401       return getLhs();
402   }
403 
404   return overflowOrDiv0 ? Attribute() : result;
405 }
406 
407 //===----------------------------------------------------------------------===//
408 // RemUIOp
409 //===----------------------------------------------------------------------===//
410 
411 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
412   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
413   if (!rhs)
414     return {};
415   auto rhsValue = rhs.getValue();
416 
417   // x % 1 = 0
418   if (rhsValue.isOneValue())
419     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
420 
421   // Don't fold if it requires division by zero.
422   if (rhsValue.isNullValue())
423     return {};
424 
425   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
426   if (!lhs)
427     return {};
428   return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue));
429 }
430 
431 //===----------------------------------------------------------------------===//
432 // RemSIOp
433 //===----------------------------------------------------------------------===//
434 
435 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
436   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
437   if (!rhs)
438     return {};
439   auto rhsValue = rhs.getValue();
440 
441   // x % 1 = 0
442   if (rhsValue.isOneValue())
443     return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0));
444 
445   // Don't fold if it requires division by zero.
446   if (rhsValue.isNullValue())
447     return {};
448 
449   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
450   if (!lhs)
451     return {};
452   return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue));
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // AndIOp
457 //===----------------------------------------------------------------------===//
458 
459 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
460   /// and(x, 0) -> 0
461   if (matchPattern(getRhs(), m_Zero()))
462     return getRhs();
463   /// and(x, allOnes) -> x
464   APInt intValue;
465   if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
466     return getLhs();
467   /// and(x, x) -> x
468   if (getLhs() == getRhs())
469     return getRhs();
470 
471   return constFoldBinaryOp<IntegerAttr>(operands,
472                                         [](APInt a, APInt b) { return a & b; });
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // OrIOp
477 //===----------------------------------------------------------------------===//
478 
479 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
480   /// or(x, 0) -> x
481   if (matchPattern(getRhs(), m_Zero()))
482     return getLhs();
483   /// or(x, x) -> x
484   if (getLhs() == getRhs())
485     return getRhs();
486   /// or(x, <all ones>) -> <all ones>
487   if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
488     if (rhsAttr.getValue().isAllOnes())
489       return rhsAttr;
490 
491   return constFoldBinaryOp<IntegerAttr>(operands,
492                                         [](APInt a, APInt b) { return a | b; });
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // XOrIOp
497 //===----------------------------------------------------------------------===//
498 
499 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
500   /// xor(x, 0) -> x
501   if (matchPattern(getRhs(), m_Zero()))
502     return getLhs();
503   /// xor(x, x) -> 0
504   if (getLhs() == getRhs())
505     return Builder(getContext()).getZeroAttr(getType());
506 
507   return constFoldBinaryOp<IntegerAttr>(operands,
508                                         [](APInt a, APInt b) { return a ^ b; });
509 }
510 
511 void arith::XOrIOp::getCanonicalizationPatterns(
512     OwningRewritePatternList &patterns, MLIRContext *context) {
513   patterns.insert<XOrINotCmpI>(context);
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // AddFOp
518 //===----------------------------------------------------------------------===//
519 
520 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
521   return constFoldBinaryOp<FloatAttr>(
522       operands, [](APFloat a, APFloat b) { return a + b; });
523 }
524 
525 //===----------------------------------------------------------------------===//
526 // SubFOp
527 //===----------------------------------------------------------------------===//
528 
529 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
530   return constFoldBinaryOp<FloatAttr>(
531       operands, [](APFloat a, APFloat b) { return a - b; });
532 }
533 
534 //===----------------------------------------------------------------------===//
535 // MulFOp
536 //===----------------------------------------------------------------------===//
537 
538 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
539   return constFoldBinaryOp<FloatAttr>(
540       operands, [](APFloat a, APFloat b) { return a * b; });
541 }
542 
543 //===----------------------------------------------------------------------===//
544 // DivFOp
545 //===----------------------------------------------------------------------===//
546 
547 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
548   return constFoldBinaryOp<FloatAttr>(
549       operands, [](APFloat a, APFloat b) { return a / b; });
550 }
551 
552 //===----------------------------------------------------------------------===//
553 // Utility functions for verifying cast ops
554 //===----------------------------------------------------------------------===//
555 
556 template <typename... Types>
557 using type_list = std::tuple<Types...> *;
558 
559 /// Returns a non-null type only if the provided type is one of the allowed
560 /// types or one of the allowed shaped types of the allowed types. Returns the
561 /// element type if a valid shaped type is provided.
562 template <typename... ShapedTypes, typename... ElementTypes>
563 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
564                               type_list<ElementTypes...>) {
565   if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
566     return {};
567 
568   auto underlyingType = getElementTypeOrSelf(type);
569   if (!underlyingType.isa<ElementTypes...>())
570     return {};
571 
572   return underlyingType;
573 }
574 
575 /// Get allowed underlying types for vectors and tensors.
576 template <typename... ElementTypes>
577 static Type getTypeIfLike(Type type) {
578   return getUnderlyingType(type, type_list<VectorType, TensorType>(),
579                            type_list<ElementTypes...>());
580 }
581 
582 /// Get allowed underlying types for vectors, tensors, and memrefs.
583 template <typename... ElementTypes>
584 static Type getTypeIfLikeOrMemRef(Type type) {
585   return getUnderlyingType(type,
586                            type_list<VectorType, TensorType, MemRefType>(),
587                            type_list<ElementTypes...>());
588 }
589 
590 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
591   return inputs.size() == 1 && outputs.size() == 1 &&
592          succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
593 }
594 
595 //===----------------------------------------------------------------------===//
596 // Verifiers for integer and floating point extension/truncation ops
597 //===----------------------------------------------------------------------===//
598 
599 // Extend ops can only extend to a wider type.
600 template <typename ValType, typename Op>
601 static LogicalResult verifyExtOp(Op op) {
602   Type srcType = getElementTypeOrSelf(op.getIn().getType());
603   Type dstType = getElementTypeOrSelf(op.getType());
604 
605   if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
606     return op.emitError("result type ")
607            << dstType << " must be wider than operand type " << srcType;
608 
609   return success();
610 }
611 
612 // Truncate ops can only truncate to a shorter type.
613 template <typename ValType, typename Op>
614 static LogicalResult verifyTruncateOp(Op op) {
615   Type srcType = getElementTypeOrSelf(op.getIn().getType());
616   Type dstType = getElementTypeOrSelf(op.getType());
617 
618   if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
619     return op.emitError("result type ")
620            << dstType << " must be shorter than operand type " << srcType;
621 
622   return success();
623 }
624 
625 /// Validate a cast that changes the width of a type.
626 template <template <typename> class WidthComparator, typename... ElementTypes>
627 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
628   if (!areValidCastInputsAndOutputs(inputs, outputs))
629     return false;
630 
631   auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
632   auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
633   if (!srcType || !dstType)
634     return false;
635 
636   return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
637                                      srcType.getIntOrFloatBitWidth());
638 }
639 
640 //===----------------------------------------------------------------------===//
641 // ExtUIOp
642 //===----------------------------------------------------------------------===//
643 
644 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
645   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
646     return IntegerAttr::get(
647         getType(), lhs.getValue().zext(getType().getIntOrFloatBitWidth()));
648 
649   return {};
650 }
651 
652 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
653   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
654 }
655 
656 //===----------------------------------------------------------------------===//
657 // ExtSIOp
658 //===----------------------------------------------------------------------===//
659 
660 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
661   if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
662     return IntegerAttr::get(
663         getType(), lhs.getValue().sext(getType().getIntOrFloatBitWidth()));
664 
665   return {};
666 }
667 
668 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
669   return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // ExtFOp
674 //===----------------------------------------------------------------------===//
675 
676 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
677   return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
678 }
679 
680 //===----------------------------------------------------------------------===//
681 // TruncIOp
682 //===----------------------------------------------------------------------===//
683 
684 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
685   // trunci(zexti(a)) -> a
686   // trunci(sexti(a)) -> a
687   if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
688       matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
689     return getOperand().getDefiningOp()->getOperand(0);
690 
691   assert(operands.size() == 1 && "unary operation takes one operand");
692 
693   if (!operands[0])
694     return {};
695 
696   if (auto lhs = operands[0].dyn_cast<IntegerAttr>()) {
697     return IntegerAttr::get(
698         getType(), lhs.getValue().trunc(getType().getIntOrFloatBitWidth()));
699   }
700 
701   return {};
702 }
703 
704 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
705   return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
706 }
707 
708 //===----------------------------------------------------------------------===//
709 // TruncFOp
710 //===----------------------------------------------------------------------===//
711 
712 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
713 /// can be represented without precision loss or rounding.
714 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
715   assert(operands.size() == 1 && "unary operation takes one operand");
716 
717   auto constOperand = operands.front();
718   if (!constOperand || !constOperand.isa<FloatAttr>())
719     return {};
720 
721   // Convert to target type via 'double'.
722   double sourceValue =
723       constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
724   auto targetAttr = FloatAttr::get(getType(), sourceValue);
725 
726   // Propagate if constant's value does not change after truncation.
727   if (sourceValue == targetAttr.getValue().convertToDouble())
728     return targetAttr;
729 
730   return {};
731 }
732 
733 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
734   return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
735 }
736 
737 //===----------------------------------------------------------------------===//
738 // Verifiers for casts between integers and floats.
739 //===----------------------------------------------------------------------===//
740 
741 template <typename From, typename To>
742 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
743   if (!areValidCastInputsAndOutputs(inputs, outputs))
744     return false;
745 
746   auto srcType = getTypeIfLike<From>(inputs.front());
747   auto dstType = getTypeIfLike<To>(outputs.back());
748 
749   return srcType && dstType;
750 }
751 
752 //===----------------------------------------------------------------------===//
753 // UIToFPOp
754 //===----------------------------------------------------------------------===//
755 
756 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
757   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // SIToFPOp
762 //===----------------------------------------------------------------------===//
763 
764 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
765   return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // FPToUIOp
770 //===----------------------------------------------------------------------===//
771 
772 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
773   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
774 }
775 
776 //===----------------------------------------------------------------------===//
777 // FPToSIOp
778 //===----------------------------------------------------------------------===//
779 
780 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
781   return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
782 }
783 
784 //===----------------------------------------------------------------------===//
785 // IndexCastOp
786 //===----------------------------------------------------------------------===//
787 
788 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
789                                            TypeRange outputs) {
790   if (!areValidCastInputsAndOutputs(inputs, outputs))
791     return false;
792 
793   auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
794   auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
795   if (!srcType || !dstType)
796     return false;
797 
798   return (srcType.isIndex() && dstType.isSignlessInteger()) ||
799          (srcType.isSignlessInteger() && dstType.isIndex());
800 }
801 
802 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
803   // index_cast(constant) -> constant
804   // A little hack because we go through int. Otherwise, the size of the
805   // constant might need to change.
806   if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
807     return IntegerAttr::get(getType(), value.getInt());
808 
809   return {};
810 }
811 
812 void arith::IndexCastOp::getCanonicalizationPatterns(
813     OwningRewritePatternList &patterns, MLIRContext *context) {
814   patterns.insert<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
815 }
816 
817 //===----------------------------------------------------------------------===//
818 // BitcastOp
819 //===----------------------------------------------------------------------===//
820 
821 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
822   if (!areValidCastInputsAndOutputs(inputs, outputs))
823     return false;
824 
825   auto srcType =
826       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
827   auto dstType =
828       getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
829   if (!srcType || !dstType)
830     return false;
831 
832   return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
833 }
834 
835 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
836   assert(operands.size() == 1 && "bitcast op expects 1 operand");
837 
838   auto resType = getType();
839   auto operand = operands[0];
840   if (!operand)
841     return {};
842 
843   /// Bitcast dense elements.
844   if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
845     return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
846   /// Other shaped types unhandled.
847   if (resType.isa<ShapedType>())
848     return {};
849 
850   /// Bitcast integer or float to integer or float.
851   APInt bits = operand.isa<FloatAttr>()
852                    ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
853                    : operand.cast<IntegerAttr>().getValue();
854 
855   if (auto resFloatType = resType.dyn_cast<FloatType>())
856     return FloatAttr::get(resType,
857                           APFloat(resFloatType.getFloatSemantics(), bits));
858   return IntegerAttr::get(resType, bits);
859 }
860 
/// Registers the canonicalization patterns for arith.bitcast.
void arith::BitcastOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  // `BitcastOfBitcast` is defined outside this chunk. By its name it
  // presumably rewrites a bitcast of a bitcast; confirm against its definition.
  patterns.insert<BitcastOfBitcast>(context);
}
865 
866 //===----------------------------------------------------------------------===//
867 // Helpers for compare ops
868 //===----------------------------------------------------------------------===//
869 
870 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
871 static Type getI1SameShape(Type type) {
872   auto i1Type = IntegerType::get(type.getContext(), 1);
873   if (auto tensorType = type.dyn_cast<RankedTensorType>())
874     return RankedTensorType::get(tensorType.getShape(), i1Type);
875   if (type.isa<UnrankedTensorType>())
876     return UnrankedTensorType::get(i1Type);
877   if (auto vectorType = type.dyn_cast<VectorType>())
878     return VectorType::get(vectorType.getShape(), i1Type);
879   return i1Type;
880 }
881 
882 //===----------------------------------------------------------------------===//
883 // CmpIOp
884 //===----------------------------------------------------------------------===//
885 
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
/// comparison predicates.
///
/// The `s*` predicates compare the operands as signed values (APInt
/// `slt`/`sle`/`sgt`/`sge`). The `u*` predicates compare them as unsigned
/// values (`ult`/`ule`/`ugt`/`uge`). `eq`/`ne` do not depend on signedness.
/// NOTE(review): both APInts presumably must share a bit width, as APInt
/// comparisons require; confirm callers guarantee this.
bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
                                    const APInt &lhs, const APInt &rhs) {
  // Exhaustive switch with no `default`, so the compiler can warn if a new
  // predicate is added without a case here.
  switch (predicate) {
  case arith::CmpIPredicate::eq:
    return lhs.eq(rhs);
  case arith::CmpIPredicate::ne:
    return lhs.ne(rhs);
  case arith::CmpIPredicate::slt:
    return lhs.slt(rhs);
  case arith::CmpIPredicate::sle:
    return lhs.sle(rhs);
  case arith::CmpIPredicate::sgt:
    return lhs.sgt(rhs);
  case arith::CmpIPredicate::sge:
    return lhs.sge(rhs);
  case arith::CmpIPredicate::ult:
    return lhs.ult(rhs);
  case arith::CmpIPredicate::ule:
    return lhs.ule(rhs);
  case arith::CmpIPredicate::ugt:
    return lhs.ugt(rhs);
  case arith::CmpIPredicate::uge:
    return lhs.uge(rhs);
  }
  // Only reachable for an out-of-range enum value.
  llvm_unreachable("unknown cmpi predicate kind");
}
914 
/// Returns true if the predicate is true for two equal operands.
///
/// The reflexive predicates (eq, sle, sge, ule, uge) hold. The strict and
/// inequality predicates (ne, slt, sgt, ult, ugt) do not. Used to fold
/// `cmpi(pred, x, x)` without knowing the value of `x`.
static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
  // Exhaustive switch with no `default`, so the compiler can warn if a new
  // predicate is added without a case here.
  switch (predicate) {
  case arith::CmpIPredicate::eq:
  case arith::CmpIPredicate::sle:
  case arith::CmpIPredicate::sge:
  case arith::CmpIPredicate::ule:
  case arith::CmpIPredicate::uge:
    return true;
  case arith::CmpIPredicate::ne:
  case arith::CmpIPredicate::slt:
  case arith::CmpIPredicate::sgt:
  case arith::CmpIPredicate::ult:
  case arith::CmpIPredicate::ugt:
    return false;
  }
  // Only reachable for an out-of-range enum value.
  llvm_unreachable("unknown cmpi predicate kind");
}
933 
934 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
935   assert(operands.size() == 2 && "cmpi takes two operands");
936 
937   // cmpi(pred, x, x)
938   if (getLhs() == getRhs()) {
939     auto val = applyCmpPredicateToEqualOperands(getPredicate());
940     return BoolAttr::get(getContext(), val);
941   }
942 
943   auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
944   auto rhs = operands.back().dyn_cast_or_null<IntegerAttr>();
945   if (!lhs || !rhs)
946     return {};
947 
948   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
949   return BoolAttr::get(getContext(), val);
950 }
951 
952 //===----------------------------------------------------------------------===//
953 // CmpFOp
954 //===----------------------------------------------------------------------===//
955 
/// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
/// comparison predicates.
///
/// The operands are compared once with `APFloat::compare`, which yields
/// `cmpUnordered` when either operand is NaN. Each predicate then tests that
/// single result:
/// - ordered (`O*`) predicates and ORD are false when unordered;
/// - unordered (`U*`) predicates, UNE, and UNO are true when unordered;
/// - AlwaysFalse / AlwaysTrue ignore the operands.
bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
                                    const APFloat &lhs, const APFloat &rhs) {
  auto cmpResult = lhs.compare(rhs);
  // Exhaustive switch with no `default`, so the compiler can warn if a new
  // predicate is added without a case here.
  switch (predicate) {
  case arith::CmpFPredicate::AlwaysFalse:
    return false;
  case arith::CmpFPredicate::OEQ:
    return cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OGT:
    return cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::OGE:
    return cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::OLT:
    return cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::OLE:
    return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ONE:
    // Ordered and not equal: must exclude the unordered case explicitly.
    return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::ORD:
    return cmpResult != APFloat::cmpUnordered;
  case arith::CmpFPredicate::UEQ:
    return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UGT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan;
  case arith::CmpFPredicate::UGE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpGreaterThan ||
           cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::ULT:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan;
  case arith::CmpFPredicate::ULE:
    return cmpResult == APFloat::cmpUnordered ||
           cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
  case arith::CmpFPredicate::UNE:
    // Unordered or not equal: anything other than an exact equal.
    return cmpResult != APFloat::cmpEqual;
  case arith::CmpFPredicate::UNO:
    return cmpResult == APFloat::cmpUnordered;
  case arith::CmpFPredicate::AlwaysTrue:
    return true;
  }
  // Only reachable for an out-of-range enum value.
  llvm_unreachable("unknown cmpf predicate kind");
}
1003 
1004 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1005   assert(operands.size() == 2 && "cmpf takes two operands");
1006 
1007   auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1008   auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1009 
1010   if (!lhs || !rhs)
1011     return {};
1012 
1013   auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1014   return BoolAttr::get(getContext(), val);
1015 }
1016 
1017 //===----------------------------------------------------------------------===//
1018 // TableGen'd op method definitions
1019 //===----------------------------------------------------------------------===//
1020 
1021 #define GET_OP_CLASSES
1022 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
1023 
1024 //===----------------------------------------------------------------------===//
1025 // TableGen'd enum attribute definitions
1026 //===----------------------------------------------------------------------===//
1027 
1028 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
1029