1 //===- ExpandTanh.cpp - Code to perform expanding tanh op -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements expansion of tanh op. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/Math/Transforms/Passes.h" 16 #include "mlir/Dialect/SCF/IR/SCF.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/Transforms/DialectConversion.h" 19 20 using namespace mlir; 21 22 /// Expands tanh op into 23 /// 1) 1-exp^{-2x} / 1+exp^{-2x}, if x => 0 24 /// 2) exp^{2x}-1 / exp^{2x}+1 , if x < 0 25 static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) { 26 auto floatType = op.getOperand().getType(); 27 Location loc = op.getLoc(); 28 auto floatOne = rewriter.getFloatAttr(floatType, 1.0); 29 auto floatTwo = rewriter.getFloatAttr(floatType, 2.0); 30 Value one = rewriter.create<arith::ConstantOp>(loc, floatOne); 31 Value two = rewriter.create<arith::ConstantOp>(loc, floatTwo); 32 Value doubledX = rewriter.create<arith::MulFOp>(loc, op.getOperand(), two); 33 34 // Case 1: tanh(x) = 1-exp^{-2x} / 1+exp^{-2x} 35 Value negDoubledX = rewriter.create<arith::NegFOp>(loc, doubledX); 36 Value exp2x = rewriter.create<math::ExpOp>(loc, negDoubledX); 37 Value dividend = rewriter.create<arith::SubFOp>(loc, one, exp2x); 38 Value divisor = rewriter.create<arith::AddFOp>(loc, one, exp2x); 39 Value positiveRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 40 41 // Case 2: tanh(x) = exp^{2x}-1 / exp^{2x}+1 42 exp2x = rewriter.create<math::ExpOp>(loc, doubledX); 43 dividend = rewriter.create<arith::SubFOp>(loc, exp2x, one); 44 divisor = rewriter.create<arith::AddFOp>(loc, exp2x, one); 45 Value negativeRes = rewriter.create<arith::DivFOp>(loc, dividend, divisor); 46 47 // tanh(x) = x >= 0 ? positiveRes : negativeRes 48 auto floatZero = rewriter.getFloatAttr(floatType, 0.0); 49 Value zero = rewriter.create<arith::ConstantOp>(loc, floatZero); 50 Value cmpRes = rewriter.create<arith::CmpFOp>(loc, arith::CmpFPredicate::OGE, 51 op.getOperand(), zero); 52 rewriter.replaceOpWithNewOp<arith::SelectOp>(op, cmpRes, positiveRes, 53 negativeRes); 54 return success(); 55 } 56 57 static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, 58 PatternRewriter &rewriter) { 59 auto operand = op.getOperand(); 60 auto elementTy = operand.getType(); 61 auto resultTy = op.getType(); 62 Location loc = op.getLoc(); 63 64 int bitWidth = elementTy.getIntOrFloatBitWidth(); 65 auto zero = 66 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 0)); 67 auto leadingZeros = rewriter.create<arith::ConstantOp>( 68 loc, IntegerAttr::get(elementTy, bitWidth)); 69 70 SmallVector<Value> operands = {operand, leadingZeros, zero}; 71 SmallVector<Type> types = {elementTy, elementTy, elementTy}; 72 SmallVector<Location> locations = {loc, loc, loc}; 73 74 auto whileOp = rewriter.create<scf::WhileOp>(loc, types, operands); 75 Block *before = 76 rewriter.createBlock(&whileOp.getBefore(), {}, types, locations); 77 Block *after = 78 rewriter.createBlock(&whileOp.getAfter(), {}, types, locations); 79 80 // The conditional block of the while loop. 81 { 82 rewriter.setInsertionPointToStart(&whileOp.getBefore().front()); 83 Value input = before->getArgument(0); 84 Value zero = before->getArgument(2); 85 86 Value inputNotZero = rewriter.create<arith::CmpIOp>( 87 loc, arith::CmpIPredicate::ne, input, zero); 88 rewriter.create<scf::ConditionOp>(loc, inputNotZero, 89 before->getArguments()); 90 } 91 92 // The body of the while loop: shift right until reaching a value of 0. 93 { 94 rewriter.setInsertionPointToStart(&whileOp.getAfter().front()); 95 Value input = after->getArgument(0); 96 Value leadingZeros = after->getArgument(1); 97 98 auto one = 99 rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(elementTy, 1)); 100 auto shifted = rewriter.create<arith::ShRUIOp>(loc, resultTy, input, one); 101 auto leadingZerosMinusOne = 102 rewriter.create<arith::SubIOp>(loc, resultTy, leadingZeros, one); 103 104 rewriter.create<scf::YieldOp>( 105 loc, 106 ValueRange({shifted, leadingZerosMinusOne, after->getArgument(2)})); 107 } 108 109 rewriter.setInsertionPointAfter(whileOp); 110 rewriter.replaceOp(op, whileOp->getResult(1)); 111 return success(); 112 } 113 114 void mlir::populateExpandCtlzPattern(RewritePatternSet &patterns) { 115 patterns.add(convertCtlzOp); 116 } 117 118 void mlir::populateExpandTanhPattern(RewritePatternSet &patterns) { 119 patterns.add(convertTanhOp); 120 } 121