1 //===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15 
16 using namespace mlir;
17 
18 namespace {
19 // Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one
20 // (conditional) Newton iteration.
21 //
22 // This as accurate as promoting the division to fp32 in the NVPTX backend, but
23 // faster because it performs less Newton iterations, avoids the slow path
24 // for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
25 // by the same divisor.
26 struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
27   using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern;
28 
29 private:
30   LogicalResult matchAndRewrite(LLVM::FDivOp op,
31                                 PatternRewriter &rewriter) const override;
32 };
33 
34 struct NVVMOptimizeForTarget
35     : public NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
36   void runOnOperation() override;
37 
getDependentDialects__anon6aae21a00111::NVVMOptimizeForTarget38   void getDependentDialects(DialectRegistry &registry) const override {
39     registry.insert<NVVM::NVVMDialect>();
40   }
41 };
42 } // namespace
43 
matchAndRewrite(LLVM::FDivOp op,PatternRewriter & rewriter) const44 LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
45                                             PatternRewriter &rewriter) const {
46   if (!op.getType().isF16())
47     return rewriter.notifyMatchFailure(op, "not f16");
48   Location loc = op.getLoc();
49 
50   Type f32Type = rewriter.getF32Type();
51   Type i32Type = rewriter.getI32Type();
52 
53   // Extend lhs and rhs to fp32.
54   Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs());
55   Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs());
56 
57   // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
58   Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
59   Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
60 
61   // Refine the approximation with one Newton iteration:
62   // float refined = approx + (lhs - approx * rhs) * rcp;
63   Value err = rewriter.create<LLVM::FMAOp>(
64       loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
65   Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
66 
67   // Use refined value if approx is normal (exponent neither all 0 or all 1).
68   Value mask = rewriter.create<LLVM::ConstantOp>(
69       loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
70   Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
71   Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
72   Value zero = rewriter.create<LLVM::ConstantOp>(
73       loc, i32Type, rewriter.getUI32IntegerAttr(0));
74   Value pred = rewriter.create<LLVM::OrOp>(
75       loc,
76       rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
77       rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
78   Value result =
79       rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
80 
81   // Replace with trucation back to fp16.
82   rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
83 
84   return success();
85 }
86 
runOnOperation()87 void NVVMOptimizeForTarget::runOnOperation() {
88   MLIRContext *ctx = getOperation()->getContext();
89   RewritePatternSet patterns(ctx);
90   patterns.add<ExpandDivF16>(ctx);
91   if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
92     return signalPassFailure();
93 }
94 
createOptimizeForTargetPass()95 std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() {
96   return std::make_unique<NVVMOptimizeForTarget>();
97 }
98