1*400fef08SChristian Sigg //===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===//
2*400fef08SChristian Sigg //
3*400fef08SChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*400fef08SChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
5*400fef08SChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*400fef08SChristian Sigg //
7*400fef08SChristian Sigg //===----------------------------------------------------------------------===//
8*400fef08SChristian Sigg 
9*400fef08SChristian Sigg #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
10*400fef08SChristian Sigg #include "PassDetail.h"
11*400fef08SChristian Sigg #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
12*400fef08SChristian Sigg #include "mlir/IR/Builders.h"
13*400fef08SChristian Sigg #include "mlir/IR/PatternMatch.h"
14*400fef08SChristian Sigg #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15*400fef08SChristian Sigg 
16*400fef08SChristian Sigg using namespace mlir;
17*400fef08SChristian Sigg 
18*400fef08SChristian Sigg namespace {
19*400fef08SChristian Sigg // Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one
20*400fef08SChristian Sigg // (conditional) Newton iteration.
21*400fef08SChristian Sigg //
22*400fef08SChristian Sigg // This as accurate as promoting the division to fp32 in the NVPTX backend, but
23*400fef08SChristian Sigg // faster because it performs less Newton iterations, avoids the slow path
24*400fef08SChristian Sigg // for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
25*400fef08SChristian Sigg // by the same divisor.
26*400fef08SChristian Sigg struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
27*400fef08SChristian Sigg   using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern;
28*400fef08SChristian Sigg 
29*400fef08SChristian Sigg private:
30*400fef08SChristian Sigg   LogicalResult matchAndRewrite(LLVM::FDivOp op,
31*400fef08SChristian Sigg                                 PatternRewriter &rewriter) const override;
32*400fef08SChristian Sigg };
33*400fef08SChristian Sigg 
34*400fef08SChristian Sigg struct NVVMOptimizeForTarget
35*400fef08SChristian Sigg     : public NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
36*400fef08SChristian Sigg   void runOnOperation() override;
37*400fef08SChristian Sigg 
getDependentDialects__anon6aae21a00111::NVVMOptimizeForTarget38*400fef08SChristian Sigg   void getDependentDialects(DialectRegistry &registry) const override {
39*400fef08SChristian Sigg     registry.insert<NVVM::NVVMDialect>();
40*400fef08SChristian Sigg   }
41*400fef08SChristian Sigg };
42*400fef08SChristian Sigg } // namespace
43*400fef08SChristian Sigg 
matchAndRewrite(LLVM::FDivOp op,PatternRewriter & rewriter) const44*400fef08SChristian Sigg LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
45*400fef08SChristian Sigg                                             PatternRewriter &rewriter) const {
46*400fef08SChristian Sigg   if (!op.getType().isF16())
47*400fef08SChristian Sigg     return rewriter.notifyMatchFailure(op, "not f16");
48*400fef08SChristian Sigg   Location loc = op.getLoc();
49*400fef08SChristian Sigg 
50*400fef08SChristian Sigg   Type f32Type = rewriter.getF32Type();
51*400fef08SChristian Sigg   Type i32Type = rewriter.getI32Type();
52*400fef08SChristian Sigg 
53*400fef08SChristian Sigg   // Extend lhs and rhs to fp32.
54*400fef08SChristian Sigg   Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs());
55*400fef08SChristian Sigg   Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs());
56*400fef08SChristian Sigg 
57*400fef08SChristian Sigg   // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
58*400fef08SChristian Sigg   Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
59*400fef08SChristian Sigg   Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
60*400fef08SChristian Sigg 
61*400fef08SChristian Sigg   // Refine the approximation with one Newton iteration:
62*400fef08SChristian Sigg   // float refined = approx + (lhs - approx * rhs) * rcp;
63*400fef08SChristian Sigg   Value err = rewriter.create<LLVM::FMAOp>(
64*400fef08SChristian Sigg       loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
65*400fef08SChristian Sigg   Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
66*400fef08SChristian Sigg 
67*400fef08SChristian Sigg   // Use refined value if approx is normal (exponent neither all 0 or all 1).
68*400fef08SChristian Sigg   Value mask = rewriter.create<LLVM::ConstantOp>(
69*400fef08SChristian Sigg       loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
70*400fef08SChristian Sigg   Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
71*400fef08SChristian Sigg   Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
72*400fef08SChristian Sigg   Value zero = rewriter.create<LLVM::ConstantOp>(
73*400fef08SChristian Sigg       loc, i32Type, rewriter.getUI32IntegerAttr(0));
74*400fef08SChristian Sigg   Value pred = rewriter.create<LLVM::OrOp>(
75*400fef08SChristian Sigg       loc,
76*400fef08SChristian Sigg       rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
77*400fef08SChristian Sigg       rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
78*400fef08SChristian Sigg   Value result =
79*400fef08SChristian Sigg       rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
80*400fef08SChristian Sigg 
81*400fef08SChristian Sigg   // Replace with trucation back to fp16.
82*400fef08SChristian Sigg   rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
83*400fef08SChristian Sigg 
84*400fef08SChristian Sigg   return success();
85*400fef08SChristian Sigg }
86*400fef08SChristian Sigg 
runOnOperation()87*400fef08SChristian Sigg void NVVMOptimizeForTarget::runOnOperation() {
88*400fef08SChristian Sigg   MLIRContext *ctx = getOperation()->getContext();
89*400fef08SChristian Sigg   RewritePatternSet patterns(ctx);
90*400fef08SChristian Sigg   patterns.add<ExpandDivF16>(ctx);
91*400fef08SChristian Sigg   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
92*400fef08SChristian Sigg     return signalPassFailure();
93*400fef08SChristian Sigg }
94*400fef08SChristian Sigg 
createOptimizeForTargetPass()95*400fef08SChristian Sigg std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() {
96*400fef08SChristian Sigg   return std::make_unique<NVVMOptimizeForTarget>();
97*400fef08SChristian Sigg }
98