//===- ExpandTanh.cpp - Expansion patterns for tanh and ctlz ops ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements expansions of the math.tanh and math.ctlz ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/Math/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
19 
20 using namespace mlir;
21 
22 /// Expands tanh op into
23 ///   1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0
24 ///   2) exp^{2x}-1 / exp^{2x}+1  , if x < 0
convertTanhOp(math::TanhOp op,PatternRewriter & rewriter)25 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
26   auto floatType = op.getOperand().getType();
27   Location loc = op.getLoc();
28   auto floatOne = rewriter.getFloatAttr(floatType, 1.0);
29   auto floatTwo = rewriter.getFloatAttr(floatType, 2.0);
30   Value one = rewriter.create<arith::ConstantOp>(loc, floatOne);
31   Value two = rewriter.create<arith::ConstantOp>(loc, floatTwo);
32   Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two);
33 
34   // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x}
35   Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX);
36   Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX);
37   Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x);
38   Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x);
39   Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
40 
41   // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1
42   exp2x = rewriter.create<math::ExpOp>(loc, doubledX);
43   dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one);
44   divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one);
45   Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor);
46 
47   // tanh(x) = x >= 0 ? positiveRes : negativeRes
48   auto floatZero = rewriter.getFloatAttr(floatType, 0.0);
49   Value zero = rewriter.create<arith::ConstantOp>(loc, floatZero);
50   Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE,
51                                                 op.getOperand(), zero);
52   rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes,
53                                                negativeRes);
54   return success();
55 }
56 
convertCtlzOp(math::CountLeadingZerosOp op,PatternRewriter & rewriter)57 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
58                                    PatternRewriter &rewriter) {
59   auto operand = op.getOperand();
60   auto elementTy = operand.getType();
61   auto resultTy = op.getType();
62   Location loc = op.getLoc();
63 
64   int bitWidth = elementTy.getIntOrFloatBitWidth();
65   auto zero =
66       rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0));
67   auto leadingZeros = rewriter.create<arith::ConstantOp>(
68       loc, IntegerAttr::get(elementTy, bitWidth));
69 
70   SmallVector<Value> operands = {operand, leadingZeros, zero};
71   SmallVector<Type> types = {elementTy, elementTy, elementTy};
72   SmallVector<Location> locations = {loc, loc, loc};
73 
74   auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
75   Block *before =
76       rewriter.createBlock(&whileOp.getBefore(), {}, types, locations);
77   Block *after =
78       rewriter.createBlock(&whileOp.getAfter(), {}, types, locations);
79 
80   // The conditional block of the while loop.
81   {
82     rewriter.setInsertionPointToStart(&whileOp.getBefore().front());
83     Value input = before->getArgument(0);
84     Value zero = before->getArgument(2);
85 
86     Value inputNotZero = rewriter.create<arith::CmpIOp>(
87         loc, arith::CmpIPredicate::ne, input, zero);
88     rewriter.create<scf::ConditionOp>(loc, inputNotZero,
89                                       before->getArguments());
90   }
91 
92   // The body of the while loop: shift right until reaching a value of 0.
93   {
94     rewriter.setInsertionPointToStart(&whileOp.getAfter().front());
95     Value input = after->getArgument(0);
96     Value leadingZeros = after->getArgument(1);
97 
98     auto one =
99         rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1));
100     auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one);
101     auto leadingZerosMinusOne =
102         rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one);
103 
104     rewriter.create<scf::YieldOp>(
105         loc,
106         ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)}));
107   }
108 
109   rewriter.setInsertionPointAfter(whileOp);
110   rewriter.replaceOp(op, whileOp->getResult(1));
111   return success();
112 }
113 
populateExpandCtlzPattern(RewritePatternSet & patterns)114 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) {
115   patterns.add(convertCtlzOp);
116 }
117 
populateExpandTanhPattern(RewritePatternSet & patterns)118 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) {
119   patterns.add(convertTanhOp);
120 }
121