1*f3bdb56dSRob Suderman //===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===//
2*f3bdb56dSRob Suderman //
3*f3bdb56dSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*f3bdb56dSRob Suderman // See https://llvm.org/LICENSE.txt for license information.
5*f3bdb56dSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*f3bdb56dSRob Suderman //
7*f3bdb56dSRob Suderman //===----------------------------------------------------------------------===//
8*f3bdb56dSRob Suderman //
// This file implements expansion of the math tanh and ctlz ops.
10*f3bdb56dSRob Suderman //
11*f3bdb56dSRob Suderman //===----------------------------------------------------------------------===//
12*f3bdb56dSRob Suderman 
13*f3bdb56dSRob Suderman #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14*f3bdb56dSRob Suderman #include "mlir/Dialect/Math/IR/Math.h"
15*f3bdb56dSRob Suderman #include "mlir/Dialect/Math/Transforms/Passes.h"
16*f3bdb56dSRob Suderman #include "mlir/Dialect/SCF/SCF.h"
17*f3bdb56dSRob Suderman #include "mlir/IR/Builders.h"
18*f3bdb56dSRob Suderman #include "mlir/Transforms/DialectConversion.h"
19*f3bdb56dSRob Suderman 
20*f3bdb56dSRob Suderman using namespace mlir;
21*f3bdb56dSRob Suderman 
22*f3bdb56dSRob Suderman /// Expands tanh op into
23*f3bdb56dSRob Suderman ///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
24*f3bdb56dSRob Suderman ///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
25*f3bdb56dSRob Suderman static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
26*f3bdb56dSRob Suderman   auto floatType = op.getOperand().getType();
27*f3bdb56dSRob Suderman   Location loc = op.getLoc();
28*f3bdb56dSRob Suderman   auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
29*f3bdb56dSRob Suderman   auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
30*f3bdb56dSRob Suderman   Value one = rewriter.create<arith::ConstantOp>(loc, floatOne);
31*f3bdb56dSRob Suderman   Value two = rewriter.create<arith::ConstantOp>(loc, floatTwo);
32*f3bdb56dSRob Suderman   Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
33*f3bdb56dSRob Suderman 
34*f3bdb56dSRob Suderman   // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
35*f3bdb56dSRob Suderman   Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
36*f3bdb56dSRob Suderman   Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
37*f3bdb56dSRob Suderman   Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
38*f3bdb56dSRob Suderman   Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
39*f3bdb56dSRob Suderman   Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
40*f3bdb56dSRob Suderman 
41*f3bdb56dSRob Suderman   // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
42*f3bdb56dSRob Suderman   exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
43*f3bdb56dSRob Suderman   dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
44*f3bdb56dSRob Suderman   divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
45*f3bdb56dSRob Suderman   Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
46*f3bdb56dSRob Suderman 
47*f3bdb56dSRob Suderman   // tanh(x) = x >= 0 ? positiveRes : negativeRes
48*f3bdb56dSRob Suderman   auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
49*f3bdb56dSRob Suderman   Value zero = rewriter.create<arith::ConstantOp>(loc, floatZero);
50*f3bdb56dSRob Suderman   Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
51*f3bdb56dSRob Suderman                                                 op.getOperand(), zero);
52*f3bdb56dSRob Suderman   rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
53*f3bdb56dSRob Suderman                                                negativeRes);
54*f3bdb56dSRob Suderman   return success();
55*f3bdb56dSRob Suderman }
56*f3bdb56dSRob Suderman 
57*f3bdb56dSRob Suderman static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
58*f3bdb56dSRob Suderman                                    PatternRewriter &rewriter) {
59*f3bdb56dSRob Suderman   auto operand = op.getOperand();
60*f3bdb56dSRob Suderman   auto elementTy = operand.getType();
61*f3bdb56dSRob Suderman   auto resultTy = op.getType();
62*f3bdb56dSRob Suderman   Location loc = op.getLoc();
63*f3bdb56dSRob Suderman 
64*f3bdb56dSRob Suderman   int bitWidth = elementTy.getIntOrFloatBitWidth();
65*f3bdb56dSRob Suderman   auto zero =
66*f3bdb56dSRob Suderman       rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
67*f3bdb56dSRob Suderman   auto leadingZeros = rewriter.create<arith::ConstantOp>(
68*f3bdb56dSRob Suderman       loc, IntegerAttr::get(elementTy, bitWidth));
69*f3bdb56dSRob Suderman 
70*f3bdb56dSRob Suderman   SmallVector<Value> operands = {operand, leadingZeros, zero};
71*f3bdb56dSRob Suderman   SmallVector<Type> types = {elementTy, elementTy, elementTy};
72*f3bdb56dSRob Suderman   SmallVector<Location> locations = {loc, loc, loc};
73*f3bdb56dSRob Suderman 
74*f3bdb56dSRob Suderman   auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
75*f3bdb56dSRob Suderman   Block *before =
76*f3bdb56dSRob Suderman       rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
77*f3bdb56dSRob Suderman   Block *after =
78*f3bdb56dSRob Suderman       rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
79*f3bdb56dSRob Suderman 
80*f3bdb56dSRob Suderman   // The conditional block of the while loop.
81*f3bdb56dSRob Suderman   {
82*f3bdb56dSRob Suderman     rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
83*f3bdb56dSRob Suderman     Value input = before->getArgument(0);
84*f3bdb56dSRob Suderman     Value zero = before->getArgument(2);
85*f3bdb56dSRob Suderman 
86*f3bdb56dSRob Suderman     Value inputNotZero = rewriter.create<arith::CmpIOp>(
87*f3bdb56dSRob Suderman         loc, arith::CmpIPredicate::ne, input, zero);
88*f3bdb56dSRob Suderman     rewriter.create<scf::ConditionOp>(loc, inputNotZero,
89*f3bdb56dSRob Suderman                                       before->getArguments());
90*f3bdb56dSRob Suderman   }
91*f3bdb56dSRob Suderman 
92*f3bdb56dSRob Suderman   // The body of the while loop: shift right until reaching a value of 0.
93*f3bdb56dSRob Suderman   {
94*f3bdb56dSRob Suderman     rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
95*f3bdb56dSRob Suderman     Value input = after->getArgument(0);
96*f3bdb56dSRob Suderman     Value leadingZeros = after->getArgument(1);
97*f3bdb56dSRob Suderman 
98*f3bdb56dSRob Suderman     auto one =
99*f3bdb56dSRob Suderman         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
100*f3bdb56dSRob Suderman     auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
101*f3bdb56dSRob Suderman     auto leadingZerosMinusOne =
102*f3bdb56dSRob Suderman         rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
103*f3bdb56dSRob Suderman 
104*f3bdb56dSRob Suderman     rewriter.create<scf::YieldOp>(
105*f3bdb56dSRob Suderman         loc,
106*f3bdb56dSRob Suderman         ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
107*f3bdb56dSRob Suderman   }
108*f3bdb56dSRob Suderman 
109*f3bdb56dSRob Suderman   rewriter.setInsertionPointAfter(whileOp);
110*f3bdb56dSRob Suderman   rewriter.replaceOp(op, whileOp->getResult(1));
111*f3bdb56dSRob Suderman   return success();
112*f3bdb56dSRob Suderman }
113*f3bdb56dSRob Suderman 
114*f3bdb56dSRob Suderman void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
115*f3bdb56dSRob Suderman   patterns.add(convertCtlzOp);
116*f3bdb56dSRob Suderman }
117*f3bdb56dSRob Suderman 
118*f3bdb56dSRob Suderman void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
119*f3bdb56dSRob Suderman   patterns.add(convertTanhOp);
120*f3bdb56dSRob Suderman }
121