1*400fef08SChristian Sigg //===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===//
2*400fef08SChristian Sigg //
3*400fef08SChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*400fef08SChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
5*400fef08SChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*400fef08SChristian Sigg //
7*400fef08SChristian Sigg //===----------------------------------------------------------------------===//
8*400fef08SChristian Sigg
9*400fef08SChristian Sigg #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
10*400fef08SChristian Sigg #include "PassDetail.h"
11*400fef08SChristian Sigg #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
12*400fef08SChristian Sigg #include "mlir/IR/Builders.h"
13*400fef08SChristian Sigg #include "mlir/IR/PatternMatch.h"
14*400fef08SChristian Sigg #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15*400fef08SChristian Sigg
16*400fef08SChristian Sigg using namespace mlir;
17*400fef08SChristian Sigg
18*400fef08SChristian Sigg namespace {
19*400fef08SChristian Sigg // Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one
20*400fef08SChristian Sigg // (conditional) Newton iteration.
21*400fef08SChristian Sigg //
22*400fef08SChristian Sigg // This as accurate as promoting the division to fp32 in the NVPTX backend, but
23*400fef08SChristian Sigg // faster because it performs less Newton iterations, avoids the slow path
24*400fef08SChristian Sigg // for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
25*400fef08SChristian Sigg // by the same divisor.
26*400fef08SChristian Sigg struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
27*400fef08SChristian Sigg using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern;
28*400fef08SChristian Sigg
29*400fef08SChristian Sigg private:
30*400fef08SChristian Sigg LogicalResult matchAndRewrite(LLVM::FDivOp op,
31*400fef08SChristian Sigg PatternRewriter &rewriter) const override;
32*400fef08SChristian Sigg };
33*400fef08SChristian Sigg
34*400fef08SChristian Sigg struct NVVMOptimizeForTarget
35*400fef08SChristian Sigg : public NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
36*400fef08SChristian Sigg void runOnOperation() override;
37*400fef08SChristian Sigg
getDependentDialects__anon6aae21a00111::NVVMOptimizeForTarget38*400fef08SChristian Sigg void getDependentDialects(DialectRegistry ®istry) const override {
39*400fef08SChristian Sigg registry.insert<NVVM::NVVMDialect>();
40*400fef08SChristian Sigg }
41*400fef08SChristian Sigg };
42*400fef08SChristian Sigg } // namespace
43*400fef08SChristian Sigg
matchAndRewrite(LLVM::FDivOp op,PatternRewriter & rewriter) const44*400fef08SChristian Sigg LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
45*400fef08SChristian Sigg PatternRewriter &rewriter) const {
46*400fef08SChristian Sigg if (!op.getType().isF16())
47*400fef08SChristian Sigg return rewriter.notifyMatchFailure(op, "not f16");
48*400fef08SChristian Sigg Location loc = op.getLoc();
49*400fef08SChristian Sigg
50*400fef08SChristian Sigg Type f32Type = rewriter.getF32Type();
51*400fef08SChristian Sigg Type i32Type = rewriter.getI32Type();
52*400fef08SChristian Sigg
53*400fef08SChristian Sigg // Extend lhs and rhs to fp32.
54*400fef08SChristian Sigg Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs());
55*400fef08SChristian Sigg Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs());
56*400fef08SChristian Sigg
57*400fef08SChristian Sigg // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
58*400fef08SChristian Sigg Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
59*400fef08SChristian Sigg Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
60*400fef08SChristian Sigg
61*400fef08SChristian Sigg // Refine the approximation with one Newton iteration:
62*400fef08SChristian Sigg // float refined = approx + (lhs - approx * rhs) * rcp;
63*400fef08SChristian Sigg Value err = rewriter.create<LLVM::FMAOp>(
64*400fef08SChristian Sigg loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
65*400fef08SChristian Sigg Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
66*400fef08SChristian Sigg
67*400fef08SChristian Sigg // Use refined value if approx is normal (exponent neither all 0 or all 1).
68*400fef08SChristian Sigg Value mask = rewriter.create<LLVM::ConstantOp>(
69*400fef08SChristian Sigg loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
70*400fef08SChristian Sigg Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
71*400fef08SChristian Sigg Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
72*400fef08SChristian Sigg Value zero = rewriter.create<LLVM::ConstantOp>(
73*400fef08SChristian Sigg loc, i32Type, rewriter.getUI32IntegerAttr(0));
74*400fef08SChristian Sigg Value pred = rewriter.create<LLVM::OrOp>(
75*400fef08SChristian Sigg loc,
76*400fef08SChristian Sigg rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
77*400fef08SChristian Sigg rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
78*400fef08SChristian Sigg Value result =
79*400fef08SChristian Sigg rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
80*400fef08SChristian Sigg
81*400fef08SChristian Sigg // Replace with trucation back to fp16.
82*400fef08SChristian Sigg rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
83*400fef08SChristian Sigg
84*400fef08SChristian Sigg return success();
85*400fef08SChristian Sigg }
86*400fef08SChristian Sigg
runOnOperation()87*400fef08SChristian Sigg void NVVMOptimizeForTarget::runOnOperation() {
88*400fef08SChristian Sigg MLIRContext *ctx = getOperation()->getContext();
89*400fef08SChristian Sigg RewritePatternSet patterns(ctx);
90*400fef08SChristian Sigg patterns.add<ExpandDivF16>(ctx);
91*400fef08SChristian Sigg if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
92*400fef08SChristian Sigg return signalPassFailure();
93*400fef08SChristian Sigg }
94*400fef08SChristian Sigg
createOptimizeForTargetPass()95*400fef08SChristian Sigg std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() {
96*400fef08SChristian Sigg return std::make_unique<NVVMOptimizeForTarget>();
97*400fef08SChristian Sigg }
98