//===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of the tanh and ctlz ops.
//
//===----------------------------------------------------------------------===//
12f3bdb56dSRob Suderman
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
21f3bdb56dSRob Suderman
22f3bdb56dSRob Suderman /// Expands tanh op into
23f3bdb56dSRob Suderman /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
24f3bdb56dSRob Suderman /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0
convertTanhOp(math::TanhOp op,PatternRewriter & rewriter)25f3bdb56dSRob Suderman static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
26f3bdb56dSRob Suderman auto floatType = op.getOperand().getType();
27f3bdb56dSRob Suderman Location loc = op.getLoc();
28f3bdb56dSRob Suderman auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
29f3bdb56dSRob Suderman auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
30f3bdb56dSRob Suderman Value one = rewriter.create<arith::ConstantOp>(loc, floatOne);
31f3bdb56dSRob Suderman Value two = rewriter.create<arith::ConstantOp>(loc, floatTwo);
32f3bdb56dSRob Suderman Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
33f3bdb56dSRob Suderman
34f3bdb56dSRob Suderman // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
35f3bdb56dSRob Suderman Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
36f3bdb56dSRob Suderman Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
37f3bdb56dSRob Suderman Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
38f3bdb56dSRob Suderman Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
39f3bdb56dSRob Suderman Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
40f3bdb56dSRob Suderman
41f3bdb56dSRob Suderman // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
42f3bdb56dSRob Suderman exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
43f3bdb56dSRob Suderman dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
44f3bdb56dSRob Suderman divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
45f3bdb56dSRob Suderman Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
46f3bdb56dSRob Suderman
47f3bdb56dSRob Suderman // tanh(x) = x >= 0 ? positiveRes : negativeRes
48f3bdb56dSRob Suderman auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
49f3bdb56dSRob Suderman Value zero = rewriter.create<arith::ConstantOp>(loc, floatZero);
50f3bdb56dSRob Suderman Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
51f3bdb56dSRob Suderman op.getOperand(), zero);
52f3bdb56dSRob Suderman rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
53f3bdb56dSRob Suderman negativeRes);
54f3bdb56dSRob Suderman return success();
55f3bdb56dSRob Suderman }
56f3bdb56dSRob Suderman
convertCtlzOp(math::CountLeadingZerosOp op,PatternRewriter & rewriter)57f3bdb56dSRob Suderman static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
58f3bdb56dSRob Suderman PatternRewriter &rewriter) {
59f3bdb56dSRob Suderman auto operand = op.getOperand();
60f3bdb56dSRob Suderman auto elementTy = operand.getType();
61f3bdb56dSRob Suderman auto resultTy = op.getType();
62f3bdb56dSRob Suderman Location loc = op.getLoc();
63f3bdb56dSRob Suderman
64f3bdb56dSRob Suderman int bitWidth = elementTy.getIntOrFloatBitWidth();
65f3bdb56dSRob Suderman auto zero =
66f3bdb56dSRob Suderman rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
67f3bdb56dSRob Suderman auto leadingZeros = rewriter.create<arith::ConstantOp>(
68f3bdb56dSRob Suderman loc, IntegerAttr::get(elementTy, bitWidth));
69f3bdb56dSRob Suderman
70f3bdb56dSRob Suderman SmallVector<Value> operands = {operand, leadingZeros, zero};
71f3bdb56dSRob Suderman SmallVector<Type> types = {elementTy, elementTy, elementTy};
72f3bdb56dSRob Suderman SmallVector<Location> locations = {loc, loc, loc};
73f3bdb56dSRob Suderman
74f3bdb56dSRob Suderman auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
75f3bdb56dSRob Suderman Block *before =
76f3bdb56dSRob Suderman rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
77f3bdb56dSRob Suderman Block *after =
78f3bdb56dSRob Suderman rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
79f3bdb56dSRob Suderman
80f3bdb56dSRob Suderman // The conditional block of the while loop.
81f3bdb56dSRob Suderman {
82f3bdb56dSRob Suderman rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
83f3bdb56dSRob Suderman Value input = before->getArgument(0);
84f3bdb56dSRob Suderman Value zero = before->getArgument(2);
85f3bdb56dSRob Suderman
86f3bdb56dSRob Suderman Value inputNotZero = rewriter.create<arith::CmpIOp>(
87f3bdb56dSRob Suderman loc, arith::CmpIPredicate::ne, input, zero);
88f3bdb56dSRob Suderman rewriter.create<scf::ConditionOp>(loc, inputNotZero,
89f3bdb56dSRob Suderman before->getArguments());
90f3bdb56dSRob Suderman }
91f3bdb56dSRob Suderman
92f3bdb56dSRob Suderman // The body of the while loop: shift right until reaching a value of 0.
93f3bdb56dSRob Suderman {
94f3bdb56dSRob Suderman rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
95f3bdb56dSRob Suderman Value input = after->getArgument(0);
96f3bdb56dSRob Suderman Value leadingZeros = after->getArgument(1);
97f3bdb56dSRob Suderman
98f3bdb56dSRob Suderman auto one =
99f3bdb56dSRob Suderman rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
100f3bdb56dSRob Suderman auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
101f3bdb56dSRob Suderman auto leadingZerosMinusOne =
102f3bdb56dSRob Suderman rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
103f3bdb56dSRob Suderman
104f3bdb56dSRob Suderman rewriter.create<scf::YieldOp>(
105f3bdb56dSRob Suderman loc,
106f3bdb56dSRob Suderman ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
107f3bdb56dSRob Suderman }
108f3bdb56dSRob Suderman
109f3bdb56dSRob Suderman rewriter.setInsertionPointAfter(whileOp);
110f3bdb56dSRob Suderman rewriter.replaceOp(op, whileOp->getResult(1));
111f3bdb56dSRob Suderman return success();
112f3bdb56dSRob Suderman }
113f3bdb56dSRob Suderman
populateExpandCtlzPattern(RewritePatternSet & patterns)114f3bdb56dSRob Suderman void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
115f3bdb56dSRob Suderman patterns.add(convertCtlzOp);
116f3bdb56dSRob Suderman }
117f3bdb56dSRob Suderman
populateExpandTanhPattern(RewritePatternSet & patterns)118f3bdb56dSRob Suderman void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
119f3bdb56dSRob Suderman patterns.add(convertTanhOp);
120f3bdb56dSRob Suderman }
121