1 //===- ArithmeticOps.cpp - MLIR Arithmetic dialect ops implementation -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <utility>
10
11 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
12 #include "mlir/Dialect/CommonFolders.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/IR/TypeUtilities.h"
18 #include "llvm/ADT/SmallString.h"
19
20 #include "llvm/ADT/APSInt.h"
21
22 using namespace mlir;
23 using namespace mlir::arith;
24
25 //===----------------------------------------------------------------------===//
26 // Pattern helpers
27 //===----------------------------------------------------------------------===//
28
addIntegerAttrs(PatternRewriter & builder,Value res,Attribute lhs,Attribute rhs)29 static IntegerAttr addIntegerAttrs(PatternRewriter &builder, Value res,
30 Attribute lhs, Attribute rhs) {
31 return builder.getIntegerAttr(res.getType(),
32 lhs.cast<IntegerAttr>().getInt() +
33 rhs.cast<IntegerAttr>().getInt());
34 }
35
subIntegerAttrs(PatternRewriter & builder,Value res,Attribute lhs,Attribute rhs)36 static IntegerAttr subIntegerAttrs(PatternRewriter &builder, Value res,
37 Attribute lhs, Attribute rhs) {
38 return builder.getIntegerAttr(res.getType(),
39 lhs.cast<IntegerAttr>().getInt() -
40 rhs.cast<IntegerAttr>().getInt());
41 }
42
43 /// Invert an integer comparison predicate.
invertPredicate(arith::CmpIPredicate pred)44 arith::CmpIPredicate arith::invertPredicate(arith::CmpIPredicate pred) {
45 switch (pred) {
46 case arith::CmpIPredicate::eq:
47 return arith::CmpIPredicate::ne;
48 case arith::CmpIPredicate::ne:
49 return arith::CmpIPredicate::eq;
50 case arith::CmpIPredicate::slt:
51 return arith::CmpIPredicate::sge;
52 case arith::CmpIPredicate::sle:
53 return arith::CmpIPredicate::sgt;
54 case arith::CmpIPredicate::sgt:
55 return arith::CmpIPredicate::sle;
56 case arith::CmpIPredicate::sge:
57 return arith::CmpIPredicate::slt;
58 case arith::CmpIPredicate::ult:
59 return arith::CmpIPredicate::uge;
60 case arith::CmpIPredicate::ule:
61 return arith::CmpIPredicate::ugt;
62 case arith::CmpIPredicate::ugt:
63 return arith::CmpIPredicate::ule;
64 case arith::CmpIPredicate::uge:
65 return arith::CmpIPredicate::ult;
66 }
67 llvm_unreachable("unknown cmpi predicate kind");
68 }
69
invertPredicate(arith::CmpIPredicateAttr pred)70 static arith::CmpIPredicateAttr invertPredicate(arith::CmpIPredicateAttr pred) {
71 return arith::CmpIPredicateAttr::get(pred.getContext(),
72 invertPredicate(pred.getValue()));
73 }
74
75 //===----------------------------------------------------------------------===//
76 // TableGen'd canonicalization patterns
77 //===----------------------------------------------------------------------===//
78
79 namespace {
80 #include "ArithmeticCanonicalization.inc"
81 } // namespace
82
83 //===----------------------------------------------------------------------===//
84 // ConstantOp
85 //===----------------------------------------------------------------------===//
86
getAsmResultNames(function_ref<void (Value,StringRef)> setNameFn)87 void arith::ConstantOp::getAsmResultNames(
88 function_ref<void(Value, StringRef)> setNameFn) {
89 auto type = getType();
90 if (auto intCst = getValue().dyn_cast<IntegerAttr>()) {
91 auto intType = type.dyn_cast<IntegerType>();
92
93 // Sugar i1 constants with 'true' and 'false'.
94 if (intType && intType.getWidth() == 1)
95 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
96
97 // Otherwise, build a complex name with the value and type.
98 SmallString<32> specialNameBuffer;
99 llvm::raw_svector_ostream specialName(specialNameBuffer);
100 specialName << 'c' << intCst.getValue();
101 if (intType)
102 specialName << '_' << type;
103 setNameFn(getResult(), specialName.str());
104 } else {
105 setNameFn(getResult(), "cst");
106 }
107 }
108
109 /// TODO: disallow arith.constant to return anything other than signless integer
110 /// or float like.
verify()111 LogicalResult arith::ConstantOp::verify() {
112 auto type = getType();
113 // The value's type must match the return type.
114 if (getValue().getType() != type) {
115 return emitOpError() << "value type " << getValue().getType()
116 << " must match return type: " << type;
117 }
118 // Integer values must be signless.
119 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
120 return emitOpError("integer return type must be signless");
121 // Any float or elements attribute are acceptable.
122 if (!getValue().isa<IntegerAttr, FloatAttr, ElementsAttr>()) {
123 return emitOpError(
124 "value must be an integer, float, or elements attribute");
125 }
126 return success();
127 }
128
isBuildableWith(Attribute value,Type type)129 bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) {
130 // The value's type must be the same as the provided type.
131 if (value.getType() != type)
132 return false;
133 // Integer values must be signless.
134 if (type.isa<IntegerType>() && !type.cast<IntegerType>().isSignless())
135 return false;
136 // Integer, float, and element attributes are buildable.
137 return value.isa<IntegerAttr, FloatAttr, ElementsAttr>();
138 }
139
fold(ArrayRef<Attribute> operands)140 OpFoldResult arith::ConstantOp::fold(ArrayRef<Attribute> operands) {
141 return getValue();
142 }
143
build(OpBuilder & builder,OperationState & result,int64_t value,unsigned width)144 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
145 int64_t value, unsigned width) {
146 auto type = builder.getIntegerType(width);
147 arith::ConstantOp::build(builder, result, type,
148 builder.getIntegerAttr(type, value));
149 }
150
build(OpBuilder & builder,OperationState & result,int64_t value,Type type)151 void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result,
152 int64_t value, Type type) {
153 assert(type.isSignlessInteger() &&
154 "ConstantIntOp can only have signless integer type values");
155 arith::ConstantOp::build(builder, result, type,
156 builder.getIntegerAttr(type, value));
157 }
158
classof(Operation * op)159 bool arith::ConstantIntOp::classof(Operation *op) {
160 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
161 return constOp.getType().isSignlessInteger();
162 return false;
163 }
164
build(OpBuilder & builder,OperationState & result,const APFloat & value,FloatType type)165 void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result,
166 const APFloat &value, FloatType type) {
167 arith::ConstantOp::build(builder, result, type,
168 builder.getFloatAttr(type, value));
169 }
170
classof(Operation * op)171 bool arith::ConstantFloatOp::classof(Operation *op) {
172 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
173 return constOp.getType().isa<FloatType>();
174 return false;
175 }
176
build(OpBuilder & builder,OperationState & result,int64_t value)177 void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result,
178 int64_t value) {
179 arith::ConstantOp::build(builder, result, builder.getIndexType(),
180 builder.getIndexAttr(value));
181 }
182
classof(Operation * op)183 bool arith::ConstantIndexOp::classof(Operation *op) {
184 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(op))
185 return constOp.getType().isIndex();
186 return false;
187 }
188
189 //===----------------------------------------------------------------------===//
190 // AddIOp
191 //===----------------------------------------------------------------------===//
192
fold(ArrayRef<Attribute> operands)193 OpFoldResult arith::AddIOp::fold(ArrayRef<Attribute> operands) {
194 // addi(x, 0) -> x
195 if (matchPattern(getRhs(), m_Zero()))
196 return getLhs();
197
198 // addi(subi(a, b), b) -> a
199 if (auto sub = getLhs().getDefiningOp<SubIOp>())
200 if (getRhs() == sub.getRhs())
201 return sub.getLhs();
202
203 // addi(b, subi(a, b)) -> a
204 if (auto sub = getRhs().getDefiningOp<SubIOp>())
205 if (getLhs() == sub.getRhs())
206 return sub.getLhs();
207
208 return constFoldBinaryOp<IntegerAttr>(
209 operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
210 }
211
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)212 void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
213 MLIRContext *context) {
214 patterns.add<AddIAddConstant, AddISubConstantRHS, AddISubConstantLHS>(
215 context);
216 }
217
218 //===----------------------------------------------------------------------===//
219 // SubIOp
220 //===----------------------------------------------------------------------===//
221
fold(ArrayRef<Attribute> operands)222 OpFoldResult arith::SubIOp::fold(ArrayRef<Attribute> operands) {
223 // subi(x,x) -> 0
224 if (getOperand(0) == getOperand(1))
225 return Builder(getContext()).getZeroAttr(getType());
226 // subi(x,0) -> x
227 if (matchPattern(getRhs(), m_Zero()))
228 return getLhs();
229
230 return constFoldBinaryOp<IntegerAttr>(
231 operands, [](APInt a, const APInt &b) { return std::move(a) - b; });
232 }
233
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)234 void arith::SubIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
235 MLIRContext *context) {
236 patterns
237 .add<SubIRHSAddConstant, SubILHSAddConstant, SubIRHSSubConstantRHS,
238 SubIRHSSubConstantLHS, SubILHSSubConstantRHS, SubILHSSubConstantLHS>(
239 context);
240 }
241
242 //===----------------------------------------------------------------------===//
243 // MulIOp
244 //===----------------------------------------------------------------------===//
245
fold(ArrayRef<Attribute> operands)246 OpFoldResult arith::MulIOp::fold(ArrayRef<Attribute> operands) {
247 // muli(x, 0) -> 0
248 if (matchPattern(getRhs(), m_Zero()))
249 return getRhs();
250 // muli(x, 1) -> x
251 if (matchPattern(getRhs(), m_One()))
252 return getOperand(0);
253 // TODO: Handle the overflow case.
254
255 // default folder
256 return constFoldBinaryOp<IntegerAttr>(
257 operands, [](const APInt &a, const APInt &b) { return a * b; });
258 }
259
260 //===----------------------------------------------------------------------===//
261 // DivUIOp
262 //===----------------------------------------------------------------------===//
263
fold(ArrayRef<Attribute> operands)264 OpFoldResult arith::DivUIOp::fold(ArrayRef<Attribute> operands) {
265 // divui (x, 1) -> x.
266 if (matchPattern(getRhs(), m_One()))
267 return getLhs();
268
269 // Don't fold if it would require a division by zero.
270 bool div0 = false;
271 auto result =
272 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
273 if (div0 || !b) {
274 div0 = true;
275 return a;
276 }
277 return a.udiv(b);
278 });
279
280 return div0 ? Attribute() : result;
281 }
282
283 //===----------------------------------------------------------------------===//
284 // DivSIOp
285 //===----------------------------------------------------------------------===//
286
fold(ArrayRef<Attribute> operands)287 OpFoldResult arith::DivSIOp::fold(ArrayRef<Attribute> operands) {
288 // divsi (x, 1) -> x.
289 if (matchPattern(getRhs(), m_One()))
290 return getLhs();
291
292 // Don't fold if it would overflow or if it requires a division by zero.
293 bool overflowOrDiv0 = false;
294 auto result =
295 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
296 if (overflowOrDiv0 || !b) {
297 overflowOrDiv0 = true;
298 return a;
299 }
300 return a.sdiv_ov(b, overflowOrDiv0);
301 });
302
303 return overflowOrDiv0 ? Attribute() : result;
304 }
305
306 //===----------------------------------------------------------------------===//
307 // Ceil and floor division folding helpers
308 //===----------------------------------------------------------------------===//
309
/// Computes ceil(a / b) for a >= 0 and b > 0 as (a - 1) / b + 1.
///
/// a == 0 is special-cased to 0. The formula would yield 1 there whenever
/// b > 1. b must be nonzero (division by zero).
///
/// `overflow` is assigned, like the APInt `*_ov` helpers do. It is true iff
/// any intermediate step overflowed. Each step uses its own flag because
/// those helpers overwrite their out-parameter. Chaining them through one
/// flag would let a later, non-overflowing step hide an earlier overflow
/// (e.g. when a caller passes a negated signed-minimum value).
static APInt signedCeilNonnegInputs(const APInt &a, const APInt &b,
                                    bool &overflow) {
  overflow = false;
  if (!a)
    return a;
  APInt one(a.getBitWidth(), 1, true); // Signed value 1.
  bool subOverflow = false;
  bool divOverflow = false;
  bool addOverflow = false;
  APInt val = a.ssub_ov(one, subOverflow).sdiv_ov(b, divOverflow);
  APInt res = val.sadd_ov(one, addOverflow);
  overflow = subOverflow || divOverflow || addOverflow;
  return res;
}
317
318 //===----------------------------------------------------------------------===//
319 // CeilDivUIOp
320 //===----------------------------------------------------------------------===//
321
fold(ArrayRef<Attribute> operands)322 OpFoldResult arith::CeilDivUIOp::fold(ArrayRef<Attribute> operands) {
323 // ceildivui (x, 1) -> x.
324 if (matchPattern(getRhs(), m_One()))
325 return getLhs();
326
327 bool overflowOrDiv0 = false;
328 auto result =
329 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
330 if (overflowOrDiv0 || !b) {
331 overflowOrDiv0 = true;
332 return a;
333 }
334 APInt quotient = a.udiv(b);
335 if (!a.urem(b))
336 return quotient;
337 APInt one(a.getBitWidth(), 1, true);
338 return quotient.uadd_ov(one, overflowOrDiv0);
339 });
340
341 return overflowOrDiv0 ? Attribute() : result;
342 }
343
344 //===----------------------------------------------------------------------===//
345 // CeilDivSIOp
346 //===----------------------------------------------------------------------===//
347
fold(ArrayRef<Attribute> operands)348 OpFoldResult arith::CeilDivSIOp::fold(ArrayRef<Attribute> operands) {
349 // ceildivsi (x, 1) -> x.
350 if (matchPattern(getRhs(), m_One()))
351 return getLhs();
352
353 // Don't fold if it would overflow or if it requires a division by zero.
354 bool overflowOrDiv0 = false;
355 auto result =
356 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
357 if (overflowOrDiv0 || !b) {
358 overflowOrDiv0 = true;
359 return a;
360 }
361 if (!a)
362 return a;
363 // After this point we know that neither a or b are zero.
364 unsigned bits = a.getBitWidth();
365 APInt zero = APInt::getZero(bits);
366 bool aGtZero = a.sgt(zero);
367 bool bGtZero = b.sgt(zero);
368 if (aGtZero && bGtZero) {
369 // Both positive, return ceil(a, b).
370 return signedCeilNonnegInputs(a, b, overflowOrDiv0);
371 }
372 if (!aGtZero && !bGtZero) {
373 // Both negative, return ceil(-a, -b).
374 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
375 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
376 return signedCeilNonnegInputs(posA, posB, overflowOrDiv0);
377 }
378 if (!aGtZero && bGtZero) {
379 // A is negative, b is positive, return - ( -a / b).
380 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
381 APInt div = posA.sdiv_ov(b, overflowOrDiv0);
382 return zero.ssub_ov(div, overflowOrDiv0);
383 }
384 // A is positive, b is negative, return - (a / -b).
385 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
386 APInt div = a.sdiv_ov(posB, overflowOrDiv0);
387 return zero.ssub_ov(div, overflowOrDiv0);
388 });
389
390 return overflowOrDiv0 ? Attribute() : result;
391 }
392
393 //===----------------------------------------------------------------------===//
394 // FloorDivSIOp
395 //===----------------------------------------------------------------------===//
396
fold(ArrayRef<Attribute> operands)397 OpFoldResult arith::FloorDivSIOp::fold(ArrayRef<Attribute> operands) {
398 // floordivsi (x, 1) -> x.
399 if (matchPattern(getRhs(), m_One()))
400 return getLhs();
401
402 // Don't fold if it would overflow or if it requires a division by zero.
403 bool overflowOrDiv0 = false;
404 auto result =
405 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
406 if (overflowOrDiv0 || !b) {
407 overflowOrDiv0 = true;
408 return a;
409 }
410 if (!a)
411 return a;
412 // After this point we know that neither a or b are zero.
413 unsigned bits = a.getBitWidth();
414 APInt zero = APInt::getZero(bits);
415 bool aGtZero = a.sgt(zero);
416 bool bGtZero = b.sgt(zero);
417 if (aGtZero && bGtZero) {
418 // Both positive, return a / b.
419 return a.sdiv_ov(b, overflowOrDiv0);
420 }
421 if (!aGtZero && !bGtZero) {
422 // Both negative, return -a / -b.
423 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
424 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
425 return posA.sdiv_ov(posB, overflowOrDiv0);
426 }
427 if (!aGtZero && bGtZero) {
428 // A is negative, b is positive, return - ceil(-a, b).
429 APInt posA = zero.ssub_ov(a, overflowOrDiv0);
430 APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0);
431 return zero.ssub_ov(ceil, overflowOrDiv0);
432 }
433 // A is positive, b is negative, return - ceil(a, -b).
434 APInt posB = zero.ssub_ov(b, overflowOrDiv0);
435 APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0);
436 return zero.ssub_ov(ceil, overflowOrDiv0);
437 });
438
439 return overflowOrDiv0 ? Attribute() : result;
440 }
441
442 //===----------------------------------------------------------------------===//
443 // RemUIOp
444 //===----------------------------------------------------------------------===//
445
fold(ArrayRef<Attribute> operands)446 OpFoldResult arith::RemUIOp::fold(ArrayRef<Attribute> operands) {
447 // remui (x, 1) -> 0.
448 if (matchPattern(getRhs(), m_One()))
449 return Builder(getContext()).getZeroAttr(getType());
450
451 // Don't fold if it would require a division by zero.
452 bool div0 = false;
453 auto result =
454 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
455 if (div0 || b.isNullValue()) {
456 div0 = true;
457 return a;
458 }
459 return a.urem(b);
460 });
461
462 return div0 ? Attribute() : result;
463 }
464
465 //===----------------------------------------------------------------------===//
466 // RemSIOp
467 //===----------------------------------------------------------------------===//
468
fold(ArrayRef<Attribute> operands)469 OpFoldResult arith::RemSIOp::fold(ArrayRef<Attribute> operands) {
470 // remsi (x, 1) -> 0.
471 if (matchPattern(getRhs(), m_One()))
472 return Builder(getContext()).getZeroAttr(getType());
473
474 // Don't fold if it would require a division by zero.
475 bool div0 = false;
476 auto result =
477 constFoldBinaryOp<IntegerAttr>(operands, [&](APInt a, const APInt &b) {
478 if (div0 || b.isNullValue()) {
479 div0 = true;
480 return a;
481 }
482 return a.srem(b);
483 });
484
485 return div0 ? Attribute() : result;
486 }
487
488 //===----------------------------------------------------------------------===//
489 // AndIOp
490 //===----------------------------------------------------------------------===//
491
fold(ArrayRef<Attribute> operands)492 OpFoldResult arith::AndIOp::fold(ArrayRef<Attribute> operands) {
493 /// and(x, 0) -> 0
494 if (matchPattern(getRhs(), m_Zero()))
495 return getRhs();
496 /// and(x, allOnes) -> x
497 APInt intValue;
498 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isAllOnes())
499 return getLhs();
500
501 return constFoldBinaryOp<IntegerAttr>(
502 operands, [](APInt a, const APInt &b) { return std::move(a) & b; });
503 }
504
505 //===----------------------------------------------------------------------===//
506 // OrIOp
507 //===----------------------------------------------------------------------===//
508
fold(ArrayRef<Attribute> operands)509 OpFoldResult arith::OrIOp::fold(ArrayRef<Attribute> operands) {
510 /// or(x, 0) -> x
511 if (matchPattern(getRhs(), m_Zero()))
512 return getLhs();
513 /// or(x, <all ones>) -> <all ones>
514 if (auto rhsAttr = operands[1].dyn_cast_or_null<IntegerAttr>())
515 if (rhsAttr.getValue().isAllOnes())
516 return rhsAttr;
517
518 return constFoldBinaryOp<IntegerAttr>(
519 operands, [](APInt a, const APInt &b) { return std::move(a) | b; });
520 }
521
522 //===----------------------------------------------------------------------===//
523 // XOrIOp
524 //===----------------------------------------------------------------------===//
525
fold(ArrayRef<Attribute> operands)526 OpFoldResult arith::XOrIOp::fold(ArrayRef<Attribute> operands) {
527 /// xor(x, 0) -> x
528 if (matchPattern(getRhs(), m_Zero()))
529 return getLhs();
530 /// xor(x, x) -> 0
531 if (getLhs() == getRhs())
532 return Builder(getContext()).getZeroAttr(getType());
533 /// xor(xor(x, a), a) -> x
534 if (arith::XOrIOp prev = getLhs().getDefiningOp<arith::XOrIOp>())
535 if (prev.getRhs() == getRhs())
536 return prev.getLhs();
537
538 return constFoldBinaryOp<IntegerAttr>(
539 operands, [](APInt a, const APInt &b) { return std::move(a) ^ b; });
540 }
541
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)542 void arith::XOrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
543 MLIRContext *context) {
544 patterns.add<XOrINotCmpI>(context);
545 }
546
547 //===----------------------------------------------------------------------===//
548 // NegFOp
549 //===----------------------------------------------------------------------===//
550
fold(ArrayRef<Attribute> operands)551 OpFoldResult arith::NegFOp::fold(ArrayRef<Attribute> operands) {
552 /// negf(negf(x)) -> x
553 if (auto op = this->getOperand().getDefiningOp<arith::NegFOp>())
554 return op.getOperand();
555 return constFoldUnaryOp<FloatAttr>(operands,
556 [](const APFloat &a) { return -a; });
557 }
558
559 //===----------------------------------------------------------------------===//
560 // AddFOp
561 //===----------------------------------------------------------------------===//
562
fold(ArrayRef<Attribute> operands)563 OpFoldResult arith::AddFOp::fold(ArrayRef<Attribute> operands) {
564 // addf(x, -0) -> x
565 if (matchPattern(getRhs(), m_NegZeroFloat()))
566 return getLhs();
567
568 return constFoldBinaryOp<FloatAttr>(
569 operands, [](const APFloat &a, const APFloat &b) { return a + b; });
570 }
571
572 //===----------------------------------------------------------------------===//
573 // SubFOp
574 //===----------------------------------------------------------------------===//
575
fold(ArrayRef<Attribute> operands)576 OpFoldResult arith::SubFOp::fold(ArrayRef<Attribute> operands) {
577 // subf(x, +0) -> x
578 if (matchPattern(getRhs(), m_PosZeroFloat()))
579 return getLhs();
580
581 return constFoldBinaryOp<FloatAttr>(
582 operands, [](const APFloat &a, const APFloat &b) { return a - b; });
583 }
584
585 //===----------------------------------------------------------------------===//
586 // MaxFOp
587 //===----------------------------------------------------------------------===//
588
fold(ArrayRef<Attribute> operands)589 OpFoldResult arith::MaxFOp::fold(ArrayRef<Attribute> operands) {
590 assert(operands.size() == 2 && "maxf takes two operands");
591
592 // maxf(x,x) -> x
593 if (getLhs() == getRhs())
594 return getRhs();
595
596 // maxf(x, -inf) -> x
597 if (matchPattern(getRhs(), m_NegInfFloat()))
598 return getLhs();
599
600 return constFoldBinaryOp<FloatAttr>(
601 operands,
602 [](const APFloat &a, const APFloat &b) { return llvm::maximum(a, b); });
603 }
604
605 //===----------------------------------------------------------------------===//
606 // MaxSIOp
607 //===----------------------------------------------------------------------===//
608
fold(ArrayRef<Attribute> operands)609 OpFoldResult MaxSIOp::fold(ArrayRef<Attribute> operands) {
610 assert(operands.size() == 2 && "binary operation takes two operands");
611
612 // maxsi(x,x) -> x
613 if (getLhs() == getRhs())
614 return getRhs();
615
616 APInt intValue;
617 // maxsi(x,MAX_INT) -> MAX_INT
618 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
619 intValue.isMaxSignedValue())
620 return getRhs();
621
622 // maxsi(x, MIN_INT) -> x
623 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
624 intValue.isMinSignedValue())
625 return getLhs();
626
627 return constFoldBinaryOp<IntegerAttr>(operands,
628 [](const APInt &a, const APInt &b) {
629 return llvm::APIntOps::smax(a, b);
630 });
631 }
632
633 //===----------------------------------------------------------------------===//
634 // MaxUIOp
635 //===----------------------------------------------------------------------===//
636
fold(ArrayRef<Attribute> operands)637 OpFoldResult MaxUIOp::fold(ArrayRef<Attribute> operands) {
638 assert(operands.size() == 2 && "binary operation takes two operands");
639
640 // maxui(x,x) -> x
641 if (getLhs() == getRhs())
642 return getRhs();
643
644 APInt intValue;
645 // maxui(x,MAX_INT) -> MAX_INT
646 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
647 return getRhs();
648
649 // maxui(x, MIN_INT) -> x
650 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
651 return getLhs();
652
653 return constFoldBinaryOp<IntegerAttr>(operands,
654 [](const APInt &a, const APInt &b) {
655 return llvm::APIntOps::umax(a, b);
656 });
657 }
658
659 //===----------------------------------------------------------------------===//
660 // MinFOp
661 //===----------------------------------------------------------------------===//
662
fold(ArrayRef<Attribute> operands)663 OpFoldResult arith::MinFOp::fold(ArrayRef<Attribute> operands) {
664 assert(operands.size() == 2 && "minf takes two operands");
665
666 // minf(x,x) -> x
667 if (getLhs() == getRhs())
668 return getRhs();
669
670 // minf(x, +inf) -> x
671 if (matchPattern(getRhs(), m_PosInfFloat()))
672 return getLhs();
673
674 return constFoldBinaryOp<FloatAttr>(
675 operands,
676 [](const APFloat &a, const APFloat &b) { return llvm::minimum(a, b); });
677 }
678
679 //===----------------------------------------------------------------------===//
680 // MinSIOp
681 //===----------------------------------------------------------------------===//
682
fold(ArrayRef<Attribute> operands)683 OpFoldResult MinSIOp::fold(ArrayRef<Attribute> operands) {
684 assert(operands.size() == 2 && "binary operation takes two operands");
685
686 // minsi(x,x) -> x
687 if (getLhs() == getRhs())
688 return getRhs();
689
690 APInt intValue;
691 // minsi(x,MIN_INT) -> MIN_INT
692 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
693 intValue.isMinSignedValue())
694 return getRhs();
695
696 // minsi(x, MAX_INT) -> x
697 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) &&
698 intValue.isMaxSignedValue())
699 return getLhs();
700
701 return constFoldBinaryOp<IntegerAttr>(operands,
702 [](const APInt &a, const APInt &b) {
703 return llvm::APIntOps::smin(a, b);
704 });
705 }
706
707 //===----------------------------------------------------------------------===//
708 // MinUIOp
709 //===----------------------------------------------------------------------===//
710
fold(ArrayRef<Attribute> operands)711 OpFoldResult MinUIOp::fold(ArrayRef<Attribute> operands) {
712 assert(operands.size() == 2 && "binary operation takes two operands");
713
714 // minui(x,x) -> x
715 if (getLhs() == getRhs())
716 return getRhs();
717
718 APInt intValue;
719 // minui(x,MIN_INT) -> MIN_INT
720 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMinValue())
721 return getRhs();
722
723 // minui(x, MAX_INT) -> x
724 if (matchPattern(getRhs(), m_ConstantInt(&intValue)) && intValue.isMaxValue())
725 return getLhs();
726
727 return constFoldBinaryOp<IntegerAttr>(operands,
728 [](const APInt &a, const APInt &b) {
729 return llvm::APIntOps::umin(a, b);
730 });
731 }
732
733 //===----------------------------------------------------------------------===//
734 // MulFOp
735 //===----------------------------------------------------------------------===//
736
fold(ArrayRef<Attribute> operands)737 OpFoldResult arith::MulFOp::fold(ArrayRef<Attribute> operands) {
738 // mulf(x, 1) -> x
739 if (matchPattern(getRhs(), m_OneFloat()))
740 return getLhs();
741
742 return constFoldBinaryOp<FloatAttr>(
743 operands, [](const APFloat &a, const APFloat &b) { return a * b; });
744 }
745
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)746 void arith::MulFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
747 MLIRContext *context) {
748 patterns.add<MulFOfNegF>(context);
749 }
750
751 //===----------------------------------------------------------------------===//
752 // DivFOp
753 //===----------------------------------------------------------------------===//
754
fold(ArrayRef<Attribute> operands)755 OpFoldResult arith::DivFOp::fold(ArrayRef<Attribute> operands) {
756 // divf(x, 1) -> x
757 if (matchPattern(getRhs(), m_OneFloat()))
758 return getLhs();
759
760 return constFoldBinaryOp<FloatAttr>(
761 operands, [](const APFloat &a, const APFloat &b) { return a / b; });
762 }
763
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)764 void arith::DivFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
765 MLIRContext *context) {
766 patterns.add<DivFOfNegF>(context);
767 }
768
769 //===----------------------------------------------------------------------===//
770 // RemFOp
771 //===----------------------------------------------------------------------===//
772
fold(ArrayRef<Attribute> operands)773 OpFoldResult arith::RemFOp::fold(ArrayRef<Attribute> operands) {
774 return constFoldBinaryOp<FloatAttr>(operands,
775 [](const APFloat &a, const APFloat &b) {
776 APFloat result(a);
777 (void)result.remainder(b);
778 return result;
779 });
780 }
781
782 //===----------------------------------------------------------------------===//
783 // Utility functions for verifying cast ops
784 //===----------------------------------------------------------------------===//
785
/// Tag for passing a pack of types as an ordinary (value-initialized, null)
/// function argument: a pointer to a std::tuple of the pack.
template <typename... Ts>
using type_list = std::tuple<Ts...> *;
788
789 /// Returns a non-null type only if the provided type is one of the allowed
790 /// types or one of the allowed shaped types of the allowed types. Returns the
791 /// element type if a valid shaped type is provided.
792 template <typename... ShapedTypes, typename... ElementTypes>
getUnderlyingType(Type type,type_list<ShapedTypes...>,type_list<ElementTypes...>)793 static Type getUnderlyingType(Type type, type_list<ShapedTypes...>,
794 type_list<ElementTypes...>) {
795 if (type.isa<ShapedType>() && !type.isa<ShapedTypes...>())
796 return {};
797
798 auto underlyingType = getElementTypeOrSelf(type);
799 if (!underlyingType.isa<ElementTypes...>())
800 return {};
801
802 return underlyingType;
803 }
804
805 /// Get allowed underlying types for vectors and tensors.
806 template <typename... ElementTypes>
getTypeIfLike(Type type)807 static Type getTypeIfLike(Type type) {
808 return getUnderlyingType(type, type_list<VectorType, TensorType>(),
809 type_list<ElementTypes...>());
810 }
811
812 /// Get allowed underlying types for vectors, tensors, and memrefs.
813 template <typename... ElementTypes>
getTypeIfLikeOrMemRef(Type type)814 static Type getTypeIfLikeOrMemRef(Type type) {
815 return getUnderlyingType(type,
816 type_list<VectorType, TensorType, MemRefType>(),
817 type_list<ElementTypes...>());
818 }
819
areValidCastInputsAndOutputs(TypeRange inputs,TypeRange outputs)820 static bool areValidCastInputsAndOutputs(TypeRange inputs, TypeRange outputs) {
821 return inputs.size() == 1 && outputs.size() == 1 &&
822 succeeded(verifyCompatibleShapes(inputs.front(), outputs.front()));
823 }
824
825 //===----------------------------------------------------------------------===//
826 // Verifiers for integer and floating point extension/truncation ops
827 //===----------------------------------------------------------------------===//
828
829 // Extend ops can only extend to a wider type.
830 template <typename ValType, typename Op>
verifyExtOp(Op op)831 static LogicalResult verifyExtOp(Op op) {
832 Type srcType = getElementTypeOrSelf(op.getIn().getType());
833 Type dstType = getElementTypeOrSelf(op.getType());
834
835 if (srcType.cast<ValType>().getWidth() >= dstType.cast<ValType>().getWidth())
836 return op.emitError("result type ")
837 << dstType << " must be wider than operand type " << srcType;
838
839 return success();
840 }
841
842 // Truncate ops can only truncate to a shorter type.
843 template <typename ValType, typename Op>
verifyTruncateOp(Op op)844 static LogicalResult verifyTruncateOp(Op op) {
845 Type srcType = getElementTypeOrSelf(op.getIn().getType());
846 Type dstType = getElementTypeOrSelf(op.getType());
847
848 if (srcType.cast<ValType>().getWidth() <= dstType.cast<ValType>().getWidth())
849 return op.emitError("result type ")
850 << dstType << " must be shorter than operand type " << srcType;
851
852 return success();
853 }
854
855 /// Validate a cast that changes the width of a type.
856 template <template <typename> class WidthComparator, typename... ElementTypes>
checkWidthChangeCast(TypeRange inputs,TypeRange outputs)857 static bool checkWidthChangeCast(TypeRange inputs, TypeRange outputs) {
858 if (!areValidCastInputsAndOutputs(inputs, outputs))
859 return false;
860
861 auto srcType = getTypeIfLike<ElementTypes...>(inputs.front());
862 auto dstType = getTypeIfLike<ElementTypes...>(outputs.front());
863 if (!srcType || !dstType)
864 return false;
865
866 return WidthComparator<unsigned>()(dstType.getIntOrFloatBitWidth(),
867 srcType.getIntOrFloatBitWidth());
868 }
869
870 //===----------------------------------------------------------------------===//
871 // ExtUIOp
872 //===----------------------------------------------------------------------===//
873
fold(ArrayRef<Attribute> operands)874 OpFoldResult arith::ExtUIOp::fold(ArrayRef<Attribute> operands) {
875 if (auto lhs = getIn().getDefiningOp<ExtUIOp>()) {
876 getInMutable().assign(lhs.getIn());
877 return getResult();
878 }
879 Type resType = getType();
880 unsigned bitWidth;
881 if (auto shapedType = resType.dyn_cast<ShapedType>())
882 bitWidth = shapedType.getElementTypeBitWidth();
883 else
884 bitWidth = resType.getIntOrFloatBitWidth();
885 return constFoldCastOp<IntegerAttr, IntegerAttr>(
886 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
887 return a.zext(bitWidth);
888 });
889 }
890
areCastCompatible(TypeRange inputs,TypeRange outputs)891 bool arith::ExtUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
892 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
893 }
894
verify()895 LogicalResult arith::ExtUIOp::verify() {
896 return verifyExtOp<IntegerType>(*this);
897 }
898
899 //===----------------------------------------------------------------------===//
900 // ExtSIOp
901 //===----------------------------------------------------------------------===//
902
fold(ArrayRef<Attribute> operands)903 OpFoldResult arith::ExtSIOp::fold(ArrayRef<Attribute> operands) {
904 if (auto lhs = getIn().getDefiningOp<ExtSIOp>()) {
905 getInMutable().assign(lhs.getIn());
906 return getResult();
907 }
908 Type resType = getType();
909 unsigned bitWidth;
910 if (auto shapedType = resType.dyn_cast<ShapedType>())
911 bitWidth = shapedType.getElementTypeBitWidth();
912 else
913 bitWidth = resType.getIntOrFloatBitWidth();
914 return constFoldCastOp<IntegerAttr, IntegerAttr>(
915 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
916 return a.sext(bitWidth);
917 });
918 }
919
areCastCompatible(TypeRange inputs,TypeRange outputs)920 bool arith::ExtSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
921 return checkWidthChangeCast<std::greater, IntegerType>(inputs, outputs);
922 }
923
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)924 void arith::ExtSIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
925 MLIRContext *context) {
926 patterns.add<ExtSIOfExtUI>(context);
927 }
928
verify()929 LogicalResult arith::ExtSIOp::verify() {
930 return verifyExtOp<IntegerType>(*this);
931 }
932
933 //===----------------------------------------------------------------------===//
934 // ExtFOp
935 //===----------------------------------------------------------------------===//
936
areCastCompatible(TypeRange inputs,TypeRange outputs)937 bool arith::ExtFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
938 return checkWidthChangeCast<std::greater, FloatType>(inputs, outputs);
939 }
940
verify()941 LogicalResult arith::ExtFOp::verify() { return verifyExtOp<FloatType>(*this); }
942
943 //===----------------------------------------------------------------------===//
944 // TruncIOp
945 //===----------------------------------------------------------------------===//
946
fold(ArrayRef<Attribute> operands)947 OpFoldResult arith::TruncIOp::fold(ArrayRef<Attribute> operands) {
948 assert(operands.size() == 1 && "unary operation takes one operand");
949
950 // trunci(zexti(a)) -> a
951 // trunci(sexti(a)) -> a
952 if (matchPattern(getOperand(), m_Op<arith::ExtUIOp>()) ||
953 matchPattern(getOperand(), m_Op<arith::ExtSIOp>()))
954 return getOperand().getDefiningOp()->getOperand(0);
955
956 // trunci(trunci(a)) -> trunci(a))
957 if (matchPattern(getOperand(), m_Op<arith::TruncIOp>())) {
958 setOperand(getOperand().getDefiningOp()->getOperand(0));
959 return getResult();
960 }
961
962 Type resType = getType();
963 unsigned bitWidth;
964 if (auto shapedType = resType.dyn_cast<ShapedType>())
965 bitWidth = shapedType.getElementTypeBitWidth();
966 else
967 bitWidth = resType.getIntOrFloatBitWidth();
968
969 return constFoldCastOp<IntegerAttr, IntegerAttr>(
970 operands, getType(), [bitWidth](const APInt &a, bool &castStatus) {
971 return a.trunc(bitWidth);
972 });
973 }
974
areCastCompatible(TypeRange inputs,TypeRange outputs)975 bool arith::TruncIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
976 return checkWidthChangeCast<std::less, IntegerType>(inputs, outputs);
977 }
978
verify()979 LogicalResult arith::TruncIOp::verify() {
980 return verifyTruncateOp<IntegerType>(*this);
981 }
982
983 //===----------------------------------------------------------------------===//
984 // TruncFOp
985 //===----------------------------------------------------------------------===//
986
987 /// Perform safe const propagation for truncf, i.e. only propagate if FP value
988 /// can be represented without precision loss or rounding.
fold(ArrayRef<Attribute> operands)989 OpFoldResult arith::TruncFOp::fold(ArrayRef<Attribute> operands) {
990 assert(operands.size() == 1 && "unary operation takes one operand");
991
992 auto constOperand = operands.front();
993 if (!constOperand || !constOperand.isa<FloatAttr>())
994 return {};
995
996 // Convert to target type via 'double'.
997 double sourceValue =
998 constOperand.dyn_cast<FloatAttr>().getValue().convertToDouble();
999 auto targetAttr = FloatAttr::get(getType(), sourceValue);
1000
1001 // Propagate if constant's value does not change after truncation.
1002 if (sourceValue == targetAttr.getValue().convertToDouble())
1003 return targetAttr;
1004
1005 return {};
1006 }
1007
areCastCompatible(TypeRange inputs,TypeRange outputs)1008 bool arith::TruncFOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1009 return checkWidthChangeCast<std::less, FloatType>(inputs, outputs);
1010 }
1011
verify()1012 LogicalResult arith::TruncFOp::verify() {
1013 return verifyTruncateOp<FloatType>(*this);
1014 }
1015
1016 //===----------------------------------------------------------------------===//
1017 // AndIOp
1018 //===----------------------------------------------------------------------===//
1019
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1020 void arith::AndIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1021 MLIRContext *context) {
1022 patterns.add<AndOfExtUI, AndOfExtSI>(context);
1023 }
1024
1025 //===----------------------------------------------------------------------===//
1026 // OrIOp
1027 //===----------------------------------------------------------------------===//
1028
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1029 void arith::OrIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1030 MLIRContext *context) {
1031 patterns.add<OrOfExtUI, OrOfExtSI>(context);
1032 }
1033
1034 //===----------------------------------------------------------------------===//
1035 // Verifiers for casts between integers and floats.
1036 //===----------------------------------------------------------------------===//
1037
1038 template <typename From, typename To>
checkIntFloatCast(TypeRange inputs,TypeRange outputs)1039 static bool checkIntFloatCast(TypeRange inputs, TypeRange outputs) {
1040 if (!areValidCastInputsAndOutputs(inputs, outputs))
1041 return false;
1042
1043 auto srcType = getTypeIfLike<From>(inputs.front());
1044 auto dstType = getTypeIfLike<To>(outputs.back());
1045
1046 return srcType && dstType;
1047 }
1048
1049 //===----------------------------------------------------------------------===//
1050 // UIToFPOp
1051 //===----------------------------------------------------------------------===//
1052
areCastCompatible(TypeRange inputs,TypeRange outputs)1053 bool arith::UIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1054 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1055 }
1056
fold(ArrayRef<Attribute> operands)1057 OpFoldResult arith::UIToFPOp::fold(ArrayRef<Attribute> operands) {
1058 Type resType = getType();
1059 Type resEleType;
1060 if (auto shapedType = resType.dyn_cast<ShapedType>())
1061 resEleType = shapedType.getElementType();
1062 else
1063 resEleType = resType;
1064 return constFoldCastOp<IntegerAttr, FloatAttr>(
1065 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1066 FloatType floatTy = resEleType.cast<FloatType>();
1067 APFloat apf(floatTy.getFloatSemantics(),
1068 APInt::getZero(floatTy.getWidth()));
1069 apf.convertFromAPInt(a, /*IsSigned=*/false,
1070 APFloat::rmNearestTiesToEven);
1071 return apf;
1072 });
1073 }
1074
1075 //===----------------------------------------------------------------------===//
1076 // SIToFPOp
1077 //===----------------------------------------------------------------------===//
1078
areCastCompatible(TypeRange inputs,TypeRange outputs)1079 bool arith::SIToFPOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1080 return checkIntFloatCast<IntegerType, FloatType>(inputs, outputs);
1081 }
1082
fold(ArrayRef<Attribute> operands)1083 OpFoldResult arith::SIToFPOp::fold(ArrayRef<Attribute> operands) {
1084 Type resType = getType();
1085 Type resEleType;
1086 if (auto shapedType = resType.dyn_cast<ShapedType>())
1087 resEleType = shapedType.getElementType();
1088 else
1089 resEleType = resType;
1090 return constFoldCastOp<IntegerAttr, FloatAttr>(
1091 operands, getType(), [&resEleType](const APInt &a, bool &castStatus) {
1092 FloatType floatTy = resEleType.cast<FloatType>();
1093 APFloat apf(floatTy.getFloatSemantics(),
1094 APInt::getZero(floatTy.getWidth()));
1095 apf.convertFromAPInt(a, /*IsSigned=*/true,
1096 APFloat::rmNearestTiesToEven);
1097 return apf;
1098 });
1099 }
1100 //===----------------------------------------------------------------------===//
1101 // FPToUIOp
1102 //===----------------------------------------------------------------------===//
1103
areCastCompatible(TypeRange inputs,TypeRange outputs)1104 bool arith::FPToUIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1105 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1106 }
1107
fold(ArrayRef<Attribute> operands)1108 OpFoldResult arith::FPToUIOp::fold(ArrayRef<Attribute> operands) {
1109 Type resType = getType();
1110 Type resEleType;
1111 if (auto shapedType = resType.dyn_cast<ShapedType>())
1112 resEleType = shapedType.getElementType();
1113 else
1114 resEleType = resType;
1115 return constFoldCastOp<FloatAttr, IntegerAttr>(
1116 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1117 IntegerType intTy = resEleType.cast<IntegerType>();
1118 bool ignored;
1119 APSInt api(intTy.getWidth(), /*isUnsigned=*/true);
1120 castStatus = APFloat::opInvalidOp !=
1121 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1122 return api;
1123 });
1124 }
1125
1126 //===----------------------------------------------------------------------===//
1127 // FPToSIOp
1128 //===----------------------------------------------------------------------===//
1129
areCastCompatible(TypeRange inputs,TypeRange outputs)1130 bool arith::FPToSIOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1131 return checkIntFloatCast<FloatType, IntegerType>(inputs, outputs);
1132 }
1133
fold(ArrayRef<Attribute> operands)1134 OpFoldResult arith::FPToSIOp::fold(ArrayRef<Attribute> operands) {
1135 Type resType = getType();
1136 Type resEleType;
1137 if (auto shapedType = resType.dyn_cast<ShapedType>())
1138 resEleType = shapedType.getElementType();
1139 else
1140 resEleType = resType;
1141 return constFoldCastOp<FloatAttr, IntegerAttr>(
1142 operands, getType(), [&resEleType](const APFloat &a, bool &castStatus) {
1143 IntegerType intTy = resEleType.cast<IntegerType>();
1144 bool ignored;
1145 APSInt api(intTy.getWidth(), /*isUnsigned=*/false);
1146 castStatus = APFloat::opInvalidOp !=
1147 a.convertToInteger(api, APFloat::rmTowardZero, &ignored);
1148 return api;
1149 });
1150 }
1151
1152 //===----------------------------------------------------------------------===//
1153 // IndexCastOp
1154 //===----------------------------------------------------------------------===//
1155
areCastCompatible(TypeRange inputs,TypeRange outputs)1156 bool arith::IndexCastOp::areCastCompatible(TypeRange inputs,
1157 TypeRange outputs) {
1158 if (!areValidCastInputsAndOutputs(inputs, outputs))
1159 return false;
1160
1161 auto srcType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(inputs.front());
1162 auto dstType = getTypeIfLikeOrMemRef<IntegerType, IndexType>(outputs.front());
1163 if (!srcType || !dstType)
1164 return false;
1165
1166 return (srcType.isIndex() && dstType.isSignlessInteger()) ||
1167 (srcType.isSignlessInteger() && dstType.isIndex());
1168 }
1169
fold(ArrayRef<Attribute> operands)1170 OpFoldResult arith::IndexCastOp::fold(ArrayRef<Attribute> operands) {
1171 // index_cast(constant) -> constant
1172 // A little hack because we go through int. Otherwise, the size of the
1173 // constant might need to change.
1174 if (auto value = operands[0].dyn_cast_or_null<IntegerAttr>())
1175 return IntegerAttr::get(getType(), value.getInt());
1176
1177 return {};
1178 }
1179
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1180 void arith::IndexCastOp::getCanonicalizationPatterns(
1181 RewritePatternSet &patterns, MLIRContext *context) {
1182 patterns.add<IndexCastOfIndexCast, IndexCastOfExtSI>(context);
1183 }
1184
1185 //===----------------------------------------------------------------------===//
1186 // BitcastOp
1187 //===----------------------------------------------------------------------===//
1188
areCastCompatible(TypeRange inputs,TypeRange outputs)1189 bool arith::BitcastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1190 if (!areValidCastInputsAndOutputs(inputs, outputs))
1191 return false;
1192
1193 auto srcType =
1194 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(inputs.front());
1195 auto dstType =
1196 getTypeIfLikeOrMemRef<IntegerType, IndexType, FloatType>(outputs.front());
1197 if (!srcType || !dstType)
1198 return false;
1199
1200 return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
1201 }
1202
fold(ArrayRef<Attribute> operands)1203 OpFoldResult arith::BitcastOp::fold(ArrayRef<Attribute> operands) {
1204 assert(operands.size() == 1 && "bitcast op expects 1 operand");
1205
1206 auto resType = getType();
1207 auto operand = operands[0];
1208 if (!operand)
1209 return {};
1210
1211 /// Bitcast dense elements.
1212 if (auto denseAttr = operand.dyn_cast_or_null<DenseElementsAttr>())
1213 return denseAttr.bitcast(resType.cast<ShapedType>().getElementType());
1214 /// Other shaped types unhandled.
1215 if (resType.isa<ShapedType>())
1216 return {};
1217
1218 /// Bitcast integer or float to integer or float.
1219 APInt bits = operand.isa<FloatAttr>()
1220 ? operand.cast<FloatAttr>().getValue().bitcastToAPInt()
1221 : operand.cast<IntegerAttr>().getValue();
1222
1223 if (auto resFloatType = resType.dyn_cast<FloatType>())
1224 return FloatAttr::get(resType,
1225 APFloat(resFloatType.getFloatSemantics(), bits));
1226 return IntegerAttr::get(resType, bits);
1227 }
1228
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1229 void arith::BitcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1230 MLIRContext *context) {
1231 patterns.add<BitcastOfBitcast>(context);
1232 }
1233
1234 //===----------------------------------------------------------------------===//
1235 // Helpers for compare ops
1236 //===----------------------------------------------------------------------===//
1237
1238 /// Return the type of the same shape (scalar, vector or tensor) containing i1.
getI1SameShape(Type type)1239 static Type getI1SameShape(Type type) {
1240 auto i1Type = IntegerType::get(type.getContext(), 1);
1241 if (auto tensorType = type.dyn_cast<RankedTensorType>())
1242 return RankedTensorType::get(tensorType.getShape(), i1Type);
1243 if (type.isa<UnrankedTensorType>())
1244 return UnrankedTensorType::get(i1Type);
1245 if (auto vectorType = type.dyn_cast<VectorType>())
1246 return VectorType::get(vectorType.getShape(), i1Type,
1247 vectorType.getNumScalableDims());
1248 return i1Type;
1249 }
1250
1251 //===----------------------------------------------------------------------===//
1252 // CmpIOp
1253 //===----------------------------------------------------------------------===//
1254
1255 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer
1256 /// comparison predicates.
applyCmpPredicate(arith::CmpIPredicate predicate,const APInt & lhs,const APInt & rhs)1257 bool mlir::arith::applyCmpPredicate(arith::CmpIPredicate predicate,
1258 const APInt &lhs, const APInt &rhs) {
1259 switch (predicate) {
1260 case arith::CmpIPredicate::eq:
1261 return lhs.eq(rhs);
1262 case arith::CmpIPredicate::ne:
1263 return lhs.ne(rhs);
1264 case arith::CmpIPredicate::slt:
1265 return lhs.slt(rhs);
1266 case arith::CmpIPredicate::sle:
1267 return lhs.sle(rhs);
1268 case arith::CmpIPredicate::sgt:
1269 return lhs.sgt(rhs);
1270 case arith::CmpIPredicate::sge:
1271 return lhs.sge(rhs);
1272 case arith::CmpIPredicate::ult:
1273 return lhs.ult(rhs);
1274 case arith::CmpIPredicate::ule:
1275 return lhs.ule(rhs);
1276 case arith::CmpIPredicate::ugt:
1277 return lhs.ugt(rhs);
1278 case arith::CmpIPredicate::uge:
1279 return lhs.uge(rhs);
1280 }
1281 llvm_unreachable("unknown cmpi predicate kind");
1282 }
1283
1284 /// Returns true if the predicate is true for two equal operands.
applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate)1285 static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) {
1286 switch (predicate) {
1287 case arith::CmpIPredicate::eq:
1288 case arith::CmpIPredicate::sle:
1289 case arith::CmpIPredicate::sge:
1290 case arith::CmpIPredicate::ule:
1291 case arith::CmpIPredicate::uge:
1292 return true;
1293 case arith::CmpIPredicate::ne:
1294 case arith::CmpIPredicate::slt:
1295 case arith::CmpIPredicate::sgt:
1296 case arith::CmpIPredicate::ult:
1297 case arith::CmpIPredicate::ugt:
1298 return false;
1299 }
1300 llvm_unreachable("unknown cmpi predicate kind");
1301 }
1302
getBoolAttribute(Type type,MLIRContext * ctx,bool value)1303 static Attribute getBoolAttribute(Type type, MLIRContext *ctx, bool value) {
1304 auto boolAttr = BoolAttr::get(ctx, value);
1305 ShapedType shapedType = type.dyn_cast_or_null<ShapedType>();
1306 if (!shapedType)
1307 return boolAttr;
1308 return DenseElementsAttr::get(shapedType, boolAttr);
1309 }
1310
fold(ArrayRef<Attribute> operands)1311 OpFoldResult arith::CmpIOp::fold(ArrayRef<Attribute> operands) {
1312 assert(operands.size() == 2 && "cmpi takes two operands");
1313
1314 // cmpi(pred, x, x)
1315 if (getLhs() == getRhs()) {
1316 auto val = applyCmpPredicateToEqualOperands(getPredicate());
1317 return getBoolAttribute(getType(), getContext(), val);
1318 }
1319
1320 if (matchPattern(getRhs(), m_Zero())) {
1321 if (auto extOp = getLhs().getDefiningOp<ExtSIOp>()) {
1322 // extsi(%x : i1 -> iN) != 0 -> %x
1323 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1324 getPredicate() == arith::CmpIPredicate::ne)
1325 return extOp.getOperand();
1326 }
1327 if (auto extOp = getLhs().getDefiningOp<ExtUIOp>()) {
1328 // extui(%x : i1 -> iN) != 0 -> %x
1329 if (extOp.getOperand().getType().cast<IntegerType>().getWidth() == 1 &&
1330 getPredicate() == arith::CmpIPredicate::ne)
1331 return extOp.getOperand();
1332 }
1333 }
1334
1335 // Move constant to the right side.
1336 if (operands[0] && !operands[1]) {
1337 // Do not use invertPredicate, as it will change eq to ne and vice versa.
1338 using Pred = CmpIPredicate;
1339 const std::pair<Pred, Pred> invPreds[] = {
1340 {Pred::slt, Pred::sgt}, {Pred::sgt, Pred::slt}, {Pred::sle, Pred::sge},
1341 {Pred::sge, Pred::sle}, {Pred::ult, Pred::ugt}, {Pred::ugt, Pred::ult},
1342 {Pred::ule, Pred::uge}, {Pred::uge, Pred::ule}, {Pred::eq, Pred::eq},
1343 {Pred::ne, Pred::ne},
1344 };
1345 Pred origPred = getPredicate();
1346 for (auto pred : invPreds) {
1347 if (origPred == pred.first) {
1348 setPredicateAttr(CmpIPredicateAttr::get(getContext(), pred.second));
1349 Value lhs = getLhs();
1350 Value rhs = getRhs();
1351 getLhsMutable().assign(rhs);
1352 getRhsMutable().assign(lhs);
1353 return getResult();
1354 }
1355 }
1356 llvm_unreachable("unknown cmpi predicate kind");
1357 }
1358
1359 auto lhs = operands.front().dyn_cast_or_null<IntegerAttr>();
1360 if (!lhs)
1361 return {};
1362
1363 // We are moving constants to the right side; So if lhs is constant rhs is
1364 // guaranteed to be a constant.
1365 auto rhs = operands.back().cast<IntegerAttr>();
1366
1367 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1368 return BoolAttr::get(getContext(), val);
1369 }
1370
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1371 void arith::CmpIOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1372 MLIRContext *context) {
1373 patterns.insert<CmpIExtSI, CmpIExtUI>(context);
1374 }
1375
1376 //===----------------------------------------------------------------------===//
1377 // CmpFOp
1378 //===----------------------------------------------------------------------===//
1379
1380 /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point
1381 /// comparison predicates.
applyCmpPredicate(arith::CmpFPredicate predicate,const APFloat & lhs,const APFloat & rhs)1382 bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate,
1383 const APFloat &lhs, const APFloat &rhs) {
1384 auto cmpResult = lhs.compare(rhs);
1385 switch (predicate) {
1386 case arith::CmpFPredicate::AlwaysFalse:
1387 return false;
1388 case arith::CmpFPredicate::OEQ:
1389 return cmpResult == APFloat::cmpEqual;
1390 case arith::CmpFPredicate::OGT:
1391 return cmpResult == APFloat::cmpGreaterThan;
1392 case arith::CmpFPredicate::OGE:
1393 return cmpResult == APFloat::cmpGreaterThan ||
1394 cmpResult == APFloat::cmpEqual;
1395 case arith::CmpFPredicate::OLT:
1396 return cmpResult == APFloat::cmpLessThan;
1397 case arith::CmpFPredicate::OLE:
1398 return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1399 case arith::CmpFPredicate::ONE:
1400 return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual;
1401 case arith::CmpFPredicate::ORD:
1402 return cmpResult != APFloat::cmpUnordered;
1403 case arith::CmpFPredicate::UEQ:
1404 return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual;
1405 case arith::CmpFPredicate::UGT:
1406 return cmpResult == APFloat::cmpUnordered ||
1407 cmpResult == APFloat::cmpGreaterThan;
1408 case arith::CmpFPredicate::UGE:
1409 return cmpResult == APFloat::cmpUnordered ||
1410 cmpResult == APFloat::cmpGreaterThan ||
1411 cmpResult == APFloat::cmpEqual;
1412 case arith::CmpFPredicate::ULT:
1413 return cmpResult == APFloat::cmpUnordered ||
1414 cmpResult == APFloat::cmpLessThan;
1415 case arith::CmpFPredicate::ULE:
1416 return cmpResult == APFloat::cmpUnordered ||
1417 cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual;
1418 case arith::CmpFPredicate::UNE:
1419 return cmpResult != APFloat::cmpEqual;
1420 case arith::CmpFPredicate::UNO:
1421 return cmpResult == APFloat::cmpUnordered;
1422 case arith::CmpFPredicate::AlwaysTrue:
1423 return true;
1424 }
1425 llvm_unreachable("unknown cmpf predicate kind");
1426 }
1427
fold(ArrayRef<Attribute> operands)1428 OpFoldResult arith::CmpFOp::fold(ArrayRef<Attribute> operands) {
1429 assert(operands.size() == 2 && "cmpf takes two operands");
1430
1431 auto lhs = operands.front().dyn_cast_or_null<FloatAttr>();
1432 auto rhs = operands.back().dyn_cast_or_null<FloatAttr>();
1433
1434 // If one operand is NaN, making them both NaN does not change the result.
1435 if (lhs && lhs.getValue().isNaN())
1436 rhs = lhs;
1437 if (rhs && rhs.getValue().isNaN())
1438 lhs = rhs;
1439
1440 if (!lhs || !rhs)
1441 return {};
1442
1443 auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
1444 return BoolAttr::get(getContext(), val);
1445 }
1446
1447 class CmpFIntToFPConst final : public OpRewritePattern<CmpFOp> {
1448 public:
1449 using OpRewritePattern<CmpFOp>::OpRewritePattern;
1450
convertToIntegerPredicate(CmpFPredicate pred,bool isUnsigned)1451 static CmpIPredicate convertToIntegerPredicate(CmpFPredicate pred,
1452 bool isUnsigned) {
1453 using namespace arith;
1454 switch (pred) {
1455 case CmpFPredicate::UEQ:
1456 case CmpFPredicate::OEQ:
1457 return CmpIPredicate::eq;
1458 case CmpFPredicate::UGT:
1459 case CmpFPredicate::OGT:
1460 return isUnsigned ? CmpIPredicate::ugt : CmpIPredicate::sgt;
1461 case CmpFPredicate::UGE:
1462 case CmpFPredicate::OGE:
1463 return isUnsigned ? CmpIPredicate::uge : CmpIPredicate::sge;
1464 case CmpFPredicate::ULT:
1465 case CmpFPredicate::OLT:
1466 return isUnsigned ? CmpIPredicate::ult : CmpIPredicate::slt;
1467 case CmpFPredicate::ULE:
1468 case CmpFPredicate::OLE:
1469 return isUnsigned ? CmpIPredicate::ule : CmpIPredicate::sle;
1470 case CmpFPredicate::UNE:
1471 case CmpFPredicate::ONE:
1472 return CmpIPredicate::ne;
1473 default:
1474 llvm_unreachable("Unexpected predicate!");
1475 }
1476 }
1477
matchAndRewrite(CmpFOp op,PatternRewriter & rewriter) const1478 LogicalResult matchAndRewrite(CmpFOp op,
1479 PatternRewriter &rewriter) const override {
1480 FloatAttr flt;
1481 if (!matchPattern(op.getRhs(), m_Constant(&flt)))
1482 return failure();
1483
1484 const APFloat &rhs = flt.getValue();
1485
1486 // Don't attempt to fold a nan.
1487 if (rhs.isNaN())
1488 return failure();
1489
1490 // Get the width of the mantissa. We don't want to hack on conversions that
1491 // might lose information from the integer, e.g. "i64 -> float"
1492 FloatType floatTy = op.getRhs().getType().cast<FloatType>();
1493 int mantissaWidth = floatTy.getFPMantissaWidth();
1494 if (mantissaWidth <= 0)
1495 return failure();
1496
1497 bool isUnsigned;
1498 Value intVal;
1499
1500 if (auto si = op.getLhs().getDefiningOp<SIToFPOp>()) {
1501 isUnsigned = false;
1502 intVal = si.getIn();
1503 } else if (auto ui = op.getLhs().getDefiningOp<UIToFPOp>()) {
1504 isUnsigned = true;
1505 intVal = ui.getIn();
1506 } else {
1507 return failure();
1508 }
1509
1510 // Check to see that the input is converted from an integer type that is
1511 // small enough that preserves all bits.
1512 auto intTy = intVal.getType().cast<IntegerType>();
1513 auto intWidth = intTy.getWidth();
1514
1515 // Number of bits representing values, as opposed to the sign
1516 auto valueBits = isUnsigned ? intWidth : (intWidth - 1);
1517
1518 // Following test does NOT adjust intWidth downwards for signed inputs,
1519 // because the most negative value still requires all the mantissa bits
1520 // to distinguish it from one less than that value.
1521 if ((int)intWidth > mantissaWidth) {
1522 // Conversion would lose accuracy. Check if loss can impact comparison.
1523 int exponent = ilogb(rhs);
1524 if (exponent == APFloat::IEK_Inf) {
1525 int maxExponent = ilogb(APFloat::getLargest(rhs.getSemantics()));
1526 if (maxExponent < (int)valueBits) {
1527 // Conversion could create infinity.
1528 return failure();
1529 }
1530 } else {
1531 // Note that if rhs is zero or NaN, then Exp is negative
1532 // and first condition is trivially false.
1533 if (mantissaWidth <= exponent && exponent <= (int)valueBits) {
1534 // Conversion could affect comparison.
1535 return failure();
1536 }
1537 }
1538 }
1539
1540 // Convert to equivalent cmpi predicate
1541 CmpIPredicate pred;
1542 switch (op.getPredicate()) {
1543 case CmpFPredicate::ORD:
1544 // Int to fp conversion doesn't create a nan (ord checks neither is a nan)
1545 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1546 /*width=*/1);
1547 return success();
1548 case CmpFPredicate::UNO:
1549 // Int to fp conversion doesn't create a nan (uno checks either is a nan)
1550 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1551 /*width=*/1);
1552 return success();
1553 default:
1554 pred = convertToIntegerPredicate(op.getPredicate(), isUnsigned);
1555 break;
1556 }
1557
1558 if (!isUnsigned) {
1559 // If the rhs value is > SignedMax, fold the comparison. This handles
1560 // +INF and large values.
1561 APFloat signedMax(rhs.getSemantics());
1562 signedMax.convertFromAPInt(APInt::getSignedMaxValue(intWidth), true,
1563 APFloat::rmNearestTiesToEven);
1564 if (signedMax < rhs) { // smax < 13123.0
1565 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::slt ||
1566 pred == CmpIPredicate::sle)
1567 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1568 /*width=*/1);
1569 else
1570 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1571 /*width=*/1);
1572 return success();
1573 }
1574 } else {
1575 // If the rhs value is > UnsignedMax, fold the comparison. This handles
1576 // +INF and large values.
1577 APFloat unsignedMax(rhs.getSemantics());
1578 unsignedMax.convertFromAPInt(APInt::getMaxValue(intWidth), false,
1579 APFloat::rmNearestTiesToEven);
1580 if (unsignedMax < rhs) { // umax < 13123.0
1581 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ult ||
1582 pred == CmpIPredicate::ule)
1583 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1584 /*width=*/1);
1585 else
1586 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1587 /*width=*/1);
1588 return success();
1589 }
1590 }
1591
1592 if (!isUnsigned) {
1593 // See if the rhs value is < SignedMin.
1594 APFloat signedMin(rhs.getSemantics());
1595 signedMin.convertFromAPInt(APInt::getSignedMinValue(intWidth), true,
1596 APFloat::rmNearestTiesToEven);
1597 if (signedMin > rhs) { // smin > 12312.0
1598 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::sgt ||
1599 pred == CmpIPredicate::sge)
1600 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1601 /*width=*/1);
1602 else
1603 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1604 /*width=*/1);
1605 return success();
1606 }
1607 } else {
1608 // See if the rhs value is < UnsignedMin.
1609 APFloat unsignedMin(rhs.getSemantics());
1610 unsignedMin.convertFromAPInt(APInt::getMinValue(intWidth), false,
1611 APFloat::rmNearestTiesToEven);
1612 if (unsignedMin > rhs) { // umin > 12312.0
1613 if (pred == CmpIPredicate::ne || pred == CmpIPredicate::ugt ||
1614 pred == CmpIPredicate::uge)
1615 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1616 /*width=*/1);
1617 else
1618 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1619 /*width=*/1);
1620 return success();
1621 }
1622 }
1623
1624 // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or
1625 // [0, UMAX], but it may still be fractional. See if it is fractional by
1626 // casting the FP value to the integer value and back, checking for
1627 // equality. Don't do this for zero, because -0.0 is not fractional.
1628 bool ignored;
1629 APSInt rhsInt(intWidth, isUnsigned);
1630 if (APFloat::opInvalidOp ==
1631 rhs.convertToInteger(rhsInt, APFloat::rmTowardZero, &ignored)) {
1632 // Undefined behavior invoked - the destination type can't represent
1633 // the input constant.
1634 return failure();
1635 }
1636
1637 if (!rhs.isZero()) {
1638 APFloat apf(floatTy.getFloatSemantics(),
1639 APInt::getZero(floatTy.getWidth()));
1640 apf.convertFromAPInt(rhsInt, !isUnsigned, APFloat::rmNearestTiesToEven);
1641
1642 bool equal = apf == rhs;
1643 if (!equal) {
1644 // If we had a comparison against a fractional value, we have to adjust
1645 // the compare predicate and sometimes the value. rhsInt is rounded
1646 // towards zero at this point.
1647 switch (pred) {
1648 case CmpIPredicate::ne: // (float)int != 4.4 --> true
1649 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1650 /*width=*/1);
1651 return success();
1652 case CmpIPredicate::eq: // (float)int == 4.4 --> false
1653 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1654 /*width=*/1);
1655 return success();
1656 case CmpIPredicate::ule:
1657 // (float)int <= 4.4 --> int <= 4
1658 // (float)int <= -4.4 --> false
1659 if (rhs.isNegative()) {
1660 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1661 /*width=*/1);
1662 return success();
1663 }
1664 break;
1665 case CmpIPredicate::sle:
1666 // (float)int <= 4.4 --> int <= 4
1667 // (float)int <= -4.4 --> int < -4
1668 if (rhs.isNegative())
1669 pred = CmpIPredicate::slt;
1670 break;
1671 case CmpIPredicate::ult:
1672 // (float)int < -4.4 --> false
1673 // (float)int < 4.4 --> int <= 4
1674 if (rhs.isNegative()) {
1675 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/false,
1676 /*width=*/1);
1677 return success();
1678 }
1679 pred = CmpIPredicate::ule;
1680 break;
1681 case CmpIPredicate::slt:
1682 // (float)int < -4.4 --> int < -4
1683 // (float)int < 4.4 --> int <= 4
1684 if (!rhs.isNegative())
1685 pred = CmpIPredicate::sle;
1686 break;
1687 case CmpIPredicate::ugt:
1688 // (float)int > 4.4 --> int > 4
1689 // (float)int > -4.4 --> true
1690 if (rhs.isNegative()) {
1691 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1692 /*width=*/1);
1693 return success();
1694 }
1695 break;
1696 case CmpIPredicate::sgt:
1697 // (float)int > 4.4 --> int > 4
1698 // (float)int > -4.4 --> int >= -4
1699 if (rhs.isNegative())
1700 pred = CmpIPredicate::sge;
1701 break;
1702 case CmpIPredicate::uge:
1703 // (float)int >= -4.4 --> true
1704 // (float)int >= 4.4 --> int > 4
1705 if (rhs.isNegative()) {
1706 rewriter.replaceOpWithNewOp<ConstantIntOp>(op, /*value=*/true,
1707 /*width=*/1);
1708 return success();
1709 }
1710 pred = CmpIPredicate::ugt;
1711 break;
1712 case CmpIPredicate::sge:
1713 // (float)int >= -4.4 --> int >= -4
1714 // (float)int >= 4.4 --> int > 4
1715 if (!rhs.isNegative())
1716 pred = CmpIPredicate::sgt;
1717 break;
1718 }
1719 }
1720 }
1721
1722 // Lower this FP comparison into an appropriate integer version of the
1723 // comparison.
1724 rewriter.replaceOpWithNewOp<CmpIOp>(
1725 op, pred, intVal,
1726 rewriter.create<ConstantOp>(
1727 op.getLoc(), intVal.getType(),
1728 rewriter.getIntegerAttr(intVal.getType(), rhsInt)));
1729 return success();
1730 }
1731 };
1732
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)1733 void arith::CmpFOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1734 MLIRContext *context) {
1735 patterns.insert<CmpFIntToFPConst>(context);
1736 }
1737
1738 //===----------------------------------------------------------------------===//
1739 // SelectOp
1740 //===----------------------------------------------------------------------===//
1741
1742 // Transforms a select of a boolean to arithmetic operations
1743 //
1744 // arith.select %arg, %x, %y : i1
1745 //
1746 // becomes
1747 //
1748 // and(%arg, %x) or and(!%arg, %y)
1749 struct SelectI1Simplify : public OpRewritePattern<arith::SelectOp> {
1750 using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1751
matchAndRewriteSelectI1Simplify1752 LogicalResult matchAndRewrite(arith::SelectOp op,
1753 PatternRewriter &rewriter) const override {
1754 if (!op.getType().isInteger(1))
1755 return failure();
1756
1757 Value falseConstant =
1758 rewriter.create<arith::ConstantIntOp>(op.getLoc(), true, 1);
1759 Value notCondition = rewriter.create<arith::XOrIOp>(
1760 op.getLoc(), op.getCondition(), falseConstant);
1761
1762 Value trueVal = rewriter.create<arith::AndIOp>(
1763 op.getLoc(), op.getCondition(), op.getTrueValue());
1764 Value falseVal = rewriter.create<arith::AndIOp>(op.getLoc(), notCondition,
1765 op.getFalseValue());
1766 rewriter.replaceOpWithNewOp<arith::OrIOp>(op, trueVal, falseVal);
1767 return success();
1768 }
1769 };
1770
1771 // select %arg, %c1, %c0 => extui %arg
1772 struct SelectToExtUI : public OpRewritePattern<arith::SelectOp> {
1773 using OpRewritePattern<arith::SelectOp>::OpRewritePattern;
1774
matchAndRewriteSelectToExtUI1775 LogicalResult matchAndRewrite(arith::SelectOp op,
1776 PatternRewriter &rewriter) const override {
1777 // Cannot extui i1 to i1, or i1 to f32
1778 if (!op.getType().isa<IntegerType>() || op.getType().isInteger(1))
1779 return failure();
1780
1781 // select %x, c1, %c0 => extui %arg
1782 if (matchPattern(op.getTrueValue(), m_One()) &&
1783 matchPattern(op.getFalseValue(), m_Zero())) {
1784 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(op, op.getType(),
1785 op.getCondition());
1786 return success();
1787 }
1788
1789 // select %x, c0, %c1 => extui (xor %arg, true)
1790 if (matchPattern(op.getTrueValue(), m_Zero()) &&
1791 matchPattern(op.getFalseValue(), m_One())) {
1792 rewriter.replaceOpWithNewOp<arith::ExtUIOp>(
1793 op, op.getType(),
1794 rewriter.create<arith::XOrIOp>(
1795 op.getLoc(), op.getCondition(),
1796 rewriter.create<arith::ConstantIntOp>(
1797 op.getLoc(), 1, op.getCondition().getType())));
1798 return success();
1799 }
1800
1801 return failure();
1802 }
1803 };
1804
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1805 void arith::SelectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1806 MLIRContext *context) {
1807 results.add<SelectI1Simplify, SelectToExtUI>(context);
1808 }
1809
fold(ArrayRef<Attribute> operands)1810 OpFoldResult arith::SelectOp::fold(ArrayRef<Attribute> operands) {
1811 Value trueVal = getTrueValue();
1812 Value falseVal = getFalseValue();
1813 if (trueVal == falseVal)
1814 return trueVal;
1815
1816 Value condition = getCondition();
1817
1818 // select true, %0, %1 => %0
1819 if (matchPattern(condition, m_One()))
1820 return trueVal;
1821
1822 // select false, %0, %1 => %1
1823 if (matchPattern(condition, m_Zero()))
1824 return falseVal;
1825
1826 // select %x, true, false => %x
1827 if (getType().isInteger(1) && matchPattern(getTrueValue(), m_One()) &&
1828 matchPattern(getFalseValue(), m_Zero()))
1829 return condition;
1830
1831 if (auto cmp = dyn_cast_or_null<arith::CmpIOp>(condition.getDefiningOp())) {
1832 auto pred = cmp.getPredicate();
1833 if (pred == arith::CmpIPredicate::eq || pred == arith::CmpIPredicate::ne) {
1834 auto cmpLhs = cmp.getLhs();
1835 auto cmpRhs = cmp.getRhs();
1836
1837 // %0 = arith.cmpi eq, %arg0, %arg1
1838 // %1 = arith.select %0, %arg0, %arg1 => %arg1
1839
1840 // %0 = arith.cmpi ne, %arg0, %arg1
1841 // %1 = arith.select %0, %arg0, %arg1 => %arg0
1842
1843 if ((cmpLhs == trueVal && cmpRhs == falseVal) ||
1844 (cmpRhs == trueVal && cmpLhs == falseVal))
1845 return pred == arith::CmpIPredicate::ne ? trueVal : falseVal;
1846 }
1847 }
1848 return nullptr;
1849 }
1850
parse(OpAsmParser & parser,OperationState & result)1851 ParseResult SelectOp::parse(OpAsmParser &parser, OperationState &result) {
1852 Type conditionType, resultType;
1853 SmallVector<OpAsmParser::UnresolvedOperand, 3> operands;
1854 if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) ||
1855 parser.parseOptionalAttrDict(result.attributes) ||
1856 parser.parseColonType(resultType))
1857 return failure();
1858
1859 // Check for the explicit condition type if this is a masked tensor or vector.
1860 if (succeeded(parser.parseOptionalComma())) {
1861 conditionType = resultType;
1862 if (parser.parseType(resultType))
1863 return failure();
1864 } else {
1865 conditionType = parser.getBuilder().getI1Type();
1866 }
1867
1868 result.addTypes(resultType);
1869 return parser.resolveOperands(operands,
1870 {conditionType, resultType, resultType},
1871 parser.getNameLoc(), result.operands);
1872 }
1873
print(OpAsmPrinter & p)1874 void arith::SelectOp::print(OpAsmPrinter &p) {
1875 p << " " << getOperands();
1876 p.printOptionalAttrDict((*this)->getAttrs());
1877 p << " : ";
1878 if (ShapedType condType = getCondition().getType().dyn_cast<ShapedType>())
1879 p << condType << ", ";
1880 p << getType();
1881 }
1882
verify()1883 LogicalResult arith::SelectOp::verify() {
1884 Type conditionType = getCondition().getType();
1885 if (conditionType.isSignlessInteger(1))
1886 return success();
1887
1888 // If the result type is a vector or tensor, the type can be a mask with the
1889 // same elements.
1890 Type resultType = getType();
1891 if (!resultType.isa<TensorType, VectorType>())
1892 return emitOpError() << "expected condition to be a signless i1, but got "
1893 << conditionType;
1894 Type shapedConditionType = getI1SameShape(resultType);
1895 if (conditionType != shapedConditionType) {
1896 return emitOpError() << "expected condition type to have the same shape "
1897 "as the result type, expected "
1898 << shapedConditionType << ", but got "
1899 << conditionType;
1900 }
1901 return success();
1902 }
1903 //===----------------------------------------------------------------------===//
1904 // ShLIOp
1905 //===----------------------------------------------------------------------===//
1906
fold(ArrayRef<Attribute> operands)1907 OpFoldResult arith::ShLIOp::fold(ArrayRef<Attribute> operands) {
1908 // Don't fold if shifting more than the bit width.
1909 bool bounded = false;
1910 auto result = constFoldBinaryOp<IntegerAttr>(
1911 operands, [&](const APInt &a, const APInt &b) {
1912 bounded = b.ule(b.getBitWidth());
1913 return a.shl(b);
1914 });
1915 return bounded ? result : Attribute();
1916 }
1917
1918 //===----------------------------------------------------------------------===//
1919 // ShRUIOp
1920 //===----------------------------------------------------------------------===//
1921
fold(ArrayRef<Attribute> operands)1922 OpFoldResult arith::ShRUIOp::fold(ArrayRef<Attribute> operands) {
1923 // Don't fold if shifting more than the bit width.
1924 bool bounded = false;
1925 auto result = constFoldBinaryOp<IntegerAttr>(
1926 operands, [&](const APInt &a, const APInt &b) {
1927 bounded = b.ule(b.getBitWidth());
1928 return a.lshr(b);
1929 });
1930 return bounded ? result : Attribute();
1931 }
1932
1933 //===----------------------------------------------------------------------===//
1934 // ShRSIOp
1935 //===----------------------------------------------------------------------===//
1936
fold(ArrayRef<Attribute> operands)1937 OpFoldResult arith::ShRSIOp::fold(ArrayRef<Attribute> operands) {
1938 // Don't fold if shifting more than the bit width.
1939 bool bounded = false;
1940 auto result = constFoldBinaryOp<IntegerAttr>(
1941 operands, [&](const APInt &a, const APInt &b) {
1942 bounded = b.ule(b.getBitWidth());
1943 return a.ashr(b);
1944 });
1945 return bounded ? result : Attribute();
1946 }
1947
1948 //===----------------------------------------------------------------------===//
1949 // Atomic Enum
1950 //===----------------------------------------------------------------------===//
1951
1952 /// Returns the identity value attribute associated with an AtomicRMWKind op.
getIdentityValueAttr(AtomicRMWKind kind,Type resultType,OpBuilder & builder,Location loc)1953 Attribute mlir::arith::getIdentityValueAttr(AtomicRMWKind kind, Type resultType,
1954 OpBuilder &builder, Location loc) {
1955 switch (kind) {
1956 case AtomicRMWKind::maxf:
1957 return builder.getFloatAttr(
1958 resultType,
1959 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1960 /*Negative=*/true));
1961 case AtomicRMWKind::addf:
1962 case AtomicRMWKind::addi:
1963 case AtomicRMWKind::maxu:
1964 case AtomicRMWKind::ori:
1965 return builder.getZeroAttr(resultType);
1966 case AtomicRMWKind::andi:
1967 return builder.getIntegerAttr(
1968 resultType,
1969 APInt::getAllOnes(resultType.cast<IntegerType>().getWidth()));
1970 case AtomicRMWKind::maxs:
1971 return builder.getIntegerAttr(
1972 resultType,
1973 APInt::getSignedMinValue(resultType.cast<IntegerType>().getWidth()));
1974 case AtomicRMWKind::minf:
1975 return builder.getFloatAttr(
1976 resultType,
1977 APFloat::getInf(resultType.cast<FloatType>().getFloatSemantics(),
1978 /*Negative=*/false));
1979 case AtomicRMWKind::mins:
1980 return builder.getIntegerAttr(
1981 resultType,
1982 APInt::getSignedMaxValue(resultType.cast<IntegerType>().getWidth()));
1983 case AtomicRMWKind::minu:
1984 return builder.getIntegerAttr(
1985 resultType,
1986 APInt::getMaxValue(resultType.cast<IntegerType>().getWidth()));
1987 case AtomicRMWKind::muli:
1988 return builder.getIntegerAttr(resultType, 1);
1989 case AtomicRMWKind::mulf:
1990 return builder.getFloatAttr(resultType, 1);
1991 // TODO: Add remaining reduction operations.
1992 default:
1993 (void)emitOptionalError(loc, "Reduction operation type not supported");
1994 break;
1995 }
1996 return nullptr;
1997 }
1998
1999 /// Returns the identity value associated with an AtomicRMWKind op.
getIdentityValue(AtomicRMWKind op,Type resultType,OpBuilder & builder,Location loc)2000 Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType,
2001 OpBuilder &builder, Location loc) {
2002 Attribute attr = getIdentityValueAttr(op, resultType, builder, loc);
2003 return builder.create<arith::ConstantOp>(loc, attr);
2004 }
2005
2006 /// Return the value obtained by applying the reduction operation kind
2007 /// associated with a binary AtomicRMWKind op to `lhs` and `rhs`.
getReductionOp(AtomicRMWKind op,OpBuilder & builder,Location loc,Value lhs,Value rhs)2008 Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder,
2009 Location loc, Value lhs, Value rhs) {
2010 switch (op) {
2011 case AtomicRMWKind::addf:
2012 return builder.create<arith::AddFOp>(loc, lhs, rhs);
2013 case AtomicRMWKind::addi:
2014 return builder.create<arith::AddIOp>(loc, lhs, rhs);
2015 case AtomicRMWKind::mulf:
2016 return builder.create<arith::MulFOp>(loc, lhs, rhs);
2017 case AtomicRMWKind::muli:
2018 return builder.create<arith::MulIOp>(loc, lhs, rhs);
2019 case AtomicRMWKind::maxf:
2020 return builder.create<arith::MaxFOp>(loc, lhs, rhs);
2021 case AtomicRMWKind::minf:
2022 return builder.create<arith::MinFOp>(loc, lhs, rhs);
2023 case AtomicRMWKind::maxs:
2024 return builder.create<arith::MaxSIOp>(loc, lhs, rhs);
2025 case AtomicRMWKind::mins:
2026 return builder.create<arith::MinSIOp>(loc, lhs, rhs);
2027 case AtomicRMWKind::maxu:
2028 return builder.create<arith::MaxUIOp>(loc, lhs, rhs);
2029 case AtomicRMWKind::minu:
2030 return builder.create<arith::MinUIOp>(loc, lhs, rhs);
2031 case AtomicRMWKind::ori:
2032 return builder.create<arith::OrIOp>(loc, lhs, rhs);
2033 case AtomicRMWKind::andi:
2034 return builder.create<arith::AndIOp>(loc, lhs, rhs);
2035 // TODO: Add remaining reduction operations.
2036 default:
2037 (void)emitOptionalError(loc, "Reduction operation type not supported");
2038 break;
2039 }
2040 return nullptr;
2041 }
2042
2043 //===----------------------------------------------------------------------===//
2044 // TableGen'd op method definitions
2045 //===----------------------------------------------------------------------===//
2046
2047 #define GET_OP_CLASSES
2048 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.cpp.inc"
2049
2050 //===----------------------------------------------------------------------===//
2051 // TableGen'd enum attribute definitions
2052 //===----------------------------------------------------------------------===//
2053
2054 #include "mlir/Dialect/Arithmetic/IR/ArithmeticOpsEnums.cpp.inc"
2055