1*f3bdb56dSRob Suderman //===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===// 2*f3bdb56dSRob Suderman // 3*f3bdb56dSRob Suderman // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*f3bdb56dSRob Suderman // See https://llvm.org/LICENSE.txt for license information. 5*f3bdb56dSRob Suderman // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*f3bdb56dSRob Suderman // 7*f3bdb56dSRob Suderman //===----------------------------------------------------------------------===// 8*f3bdb56dSRob Suderman // 9*f3bdb56dSRob Suderman // This file implements expansion of tanh op. 10*f3bdb56dSRob Suderman // 11*f3bdb56dSRob Suderman //===----------------------------------------------------------------------===// 12*f3bdb56dSRob Suderman 13*f3bdb56dSRob Suderman #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14*f3bdb56dSRob Suderman #include "mlir/Dialect/Math/IR/Math.h" 15*f3bdb56dSRob Suderman #include "mlir/Dialect/Math/Transforms/Passes.h" 16*f3bdb56dSRob Suderman #include "mlir/Dialect/SCF/SCF.h" 17*f3bdb56dSRob Suderman #include "mlir/IR/Builders.h" 18*f3bdb56dSRob Suderman #include "mlir/Transforms/DialectConversion.h" 19*f3bdb56dSRob Suderman 20*f3bdb56dSRob Suderman using namespace mlir; 21*f3bdb56dSRob Suderman 22*f3bdb56dSRob Suderman /// Expands tanh op into 23*f3bdb56dSRob Suderman /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0 24*f3bdb56dSRob Suderman /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0 25*f3bdb56dSRob Suderman static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { 26*f3bdb56dSRob Suderman auto floatType = op.getOperand().getType(); 27*f3bdb56dSRob Suderman Location loc = op.getLoc(); 28*f3bdb56dSRob Suderman auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 29*f3bdb56dSRob Suderman auto floatTwo = rewriter.getFloatAttr(floatType, 2.0); 30*f3bdb56dSRob Suderman Value one = rewriter.create<arith::ConstantOp>(loc, floatOne); 31*f3bdb56dSRob Suderman Value two = rewriter.create<arith::ConstantOp>(loc, floatTwo); 32*f3bdb56dSRob Suderman Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two); 33*f3bdb56dSRob Suderman 34*f3bdb56dSRob Suderman // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x} 35*f3bdb56dSRob Suderman Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX); 36*f3bdb56dSRob Suderman Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX); 37*f3bdb56dSRob Suderman Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x); 38*f3bdb56dSRob Suderman Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x); 39*f3bdb56dSRob Suderman Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 40*f3bdb56dSRob Suderman 41*f3bdb56dSRob Suderman // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1 42*f3bdb56dSRob Suderman exp2x = rewriter.create<math::ExpOp>(loc, doubledX); 43*f3bdb56dSRob Suderman dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one); 44*f3bdb56dSRob Suderman divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one); 45*f3bdb56dSRob Suderman Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 46*f3bdb56dSRob Suderman 47*f3bdb56dSRob Suderman // tanh(x) = x >= 0 ? positiveRes : negativeRes 48*f3bdb56dSRob Suderman auto floatZero = rewriter.getFloatAttr(floatType, 0.0); 49*f3bdb56dSRob Suderman Value zero = rewriter.create<arith::ConstantOp>(loc, floatZero); 50*f3bdb56dSRob Suderman Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, 51*f3bdb56dSRob Suderman op.getOperand(), zero); 52*f3bdb56dSRob Suderman rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes, 53*f3bdb56dSRob Suderman negativeRes); 54*f3bdb56dSRob Suderman return success(); 55*f3bdb56dSRob Suderman } 56*f3bdb56dSRob Suderman 57*f3bdb56dSRob Suderman static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, 58*f3bdb56dSRob Suderman PatternRewriter &rewriter) { 59*f3bdb56dSRob Suderman auto operand = op.getOperand(); 60*f3bdb56dSRob Suderman auto elementTy = operand.getType(); 61*f3bdb56dSRob Suderman auto resultTy = op.getType(); 62*f3bdb56dSRob Suderman Location loc = op.getLoc(); 63*f3bdb56dSRob Suderman 64*f3bdb56dSRob Suderman int bitWidth = elementTy.getIntOrFloatBitWidth(); 65*f3bdb56dSRob Suderman auto zero = 66*f3bdb56dSRob Suderman rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0)); 67*f3bdb56dSRob Suderman auto leadingZeros = rewriter.create<arith::ConstantOp>( 68*f3bdb56dSRob Suderman loc, IntegerAttr::get(elementTy, bitWidth)); 69*f3bdb56dSRob Suderman 70*f3bdb56dSRob Suderman SmallVector<Value> operands = {operand, leadingZeros, zero}; 71*f3bdb56dSRob Suderman SmallVector<Type> types = {elementTy, elementTy, elementTy}; 72*f3bdb56dSRob Suderman SmallVector<Location> locations = {loc, loc, loc}; 73*f3bdb56dSRob Suderman 74*f3bdb56dSRob Suderman auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 75*f3bdb56dSRob Suderman Block *before = 76*f3bdb56dSRob Suderman rewriter.createBlock(&whileOp.getBefore(), {}, types, locations); 77*f3bdb56dSRob Suderman Block *after = 78*f3bdb56dSRob Suderman rewriter.createBlock(&whileOp.getAfter(), {}, types, locations); 79*f3bdb56dSRob Suderman 80*f3bdb56dSRob Suderman // The conditional block of the while loop. 81*f3bdb56dSRob Suderman { 82*f3bdb56dSRob Suderman rewriter.setInsertionPointToStart(&whileOp.getBefore().front()); 83*f3bdb56dSRob Suderman Value input = before->getArgument(0); 84*f3bdb56dSRob Suderman Value zero = before->getArgument(2); 85*f3bdb56dSRob Suderman 86*f3bdb56dSRob Suderman Value inputNotZero = rewriter.create<arith::CmpIOp>( 87*f3bdb56dSRob Suderman loc, arith::CmpIPredicate::ne, input, zero); 88*f3bdb56dSRob Suderman rewriter.create<scf::ConditionOp>(loc, inputNotZero, 89*f3bdb56dSRob Suderman before->getArguments()); 90*f3bdb56dSRob Suderman } 91*f3bdb56dSRob Suderman 92*f3bdb56dSRob Suderman // The body of the while loop: shift right until reaching a value of 0. 93*f3bdb56dSRob Suderman { 94*f3bdb56dSRob Suderman rewriter.setInsertionPointToStart(&whileOp.getAfter().front()); 95*f3bdb56dSRob Suderman Value input = after->getArgument(0); 96*f3bdb56dSRob Suderman Value leadingZeros = after->getArgument(1); 97*f3bdb56dSRob Suderman 98*f3bdb56dSRob Suderman auto one = 99*f3bdb56dSRob Suderman rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1)); 100*f3bdb56dSRob Suderman auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one); 101*f3bdb56dSRob Suderman auto leadingZerosMinusOne = 102*f3bdb56dSRob Suderman rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one); 103*f3bdb56dSRob Suderman 104*f3bdb56dSRob Suderman rewriter.create<scf::YieldOp>( 105*f3bdb56dSRob Suderman loc, 106*f3bdb56dSRob Suderman ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)})); 107*f3bdb56dSRob Suderman } 108*f3bdb56dSRob Suderman 109*f3bdb56dSRob Suderman rewriter.setInsertionPointAfter(whileOp); 110*f3bdb56dSRob Suderman rewriter.replaceOp(op, whileOp->getResult(1)); 111*f3bdb56dSRob Suderman return success(); 112*f3bdb56dSRob Suderman } 113*f3bdb56dSRob Suderman 114*f3bdb56dSRob Suderman void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { 115*f3bdb56dSRob Suderman patterns.add(convertCtlzOp); 116*f3bdb56dSRob Suderman } 117*f3bdb56dSRob Suderman 118*f3bdb56dSRob Suderman void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { 119*f3bdb56dSRob Suderman patterns.add(convertTanhOp); 120*f3bdb56dSRob Suderman } 121