1 //===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h"
10 #include "PassDetail.h"
11 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15
16 using namespace mlir;
17
18 namespace {
19 // Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one
20 // (conditional) Newton iteration.
21 //
22 // This as accurate as promoting the division to fp32 in the NVPTX backend, but
23 // faster because it performs less Newton iterations, avoids the slow path
24 // for e.g. denormals, and allows reuse of the reciprocal for multiple divisions
25 // by the same divisor.
26 struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> {
27 using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern;
28
29 private:
30 LogicalResult matchAndRewrite(LLVM::FDivOp op,
31 PatternRewriter &rewriter) const override;
32 };
33
34 struct NVVMOptimizeForTarget
35 : public NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> {
36 void runOnOperation() override;
37
getDependentDialects__anon6aae21a00111::NVVMOptimizeForTarget38 void getDependentDialects(DialectRegistry ®istry) const override {
39 registry.insert<NVVM::NVVMDialect>();
40 }
41 };
42 } // namespace
43
matchAndRewrite(LLVM::FDivOp op,PatternRewriter & rewriter) const44 LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op,
45 PatternRewriter &rewriter) const {
46 if (!op.getType().isF16())
47 return rewriter.notifyMatchFailure(op, "not f16");
48 Location loc = op.getLoc();
49
50 Type f32Type = rewriter.getF32Type();
51 Type i32Type = rewriter.getI32Type();
52
53 // Extend lhs and rhs to fp32.
54 Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs());
55 Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs());
56
57 // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp.
58 Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs);
59 Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp);
60
61 // Refine the approximation with one Newton iteration:
62 // float refined = approx + (lhs - approx * rhs) * rcp;
63 Value err = rewriter.create<LLVM::FMAOp>(
64 loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs);
65 Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx);
66
67 // Use refined value if approx is normal (exponent neither all 0 or all 1).
68 Value mask = rewriter.create<LLVM::ConstantOp>(
69 loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000));
70 Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx);
71 Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask);
72 Value zero = rewriter.create<LLVM::ConstantOp>(
73 loc, i32Type, rewriter.getUI32IntegerAttr(0));
74 Value pred = rewriter.create<LLVM::OrOp>(
75 loc,
76 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero),
77 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask));
78 Value result =
79 rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined);
80
81 // Replace with trucation back to fp16.
82 rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result);
83
84 return success();
85 }
86
runOnOperation()87 void NVVMOptimizeForTarget::runOnOperation() {
88 MLIRContext *ctx = getOperation()->getContext();
89 RewritePatternSet patterns(ctx);
90 patterns.add<ExpandDivF16>(ctx);
91 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
92 return signalPassFailure();
93 }
94
createOptimizeForTargetPass()95 std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() {
96 return std::make_unique<NVVMOptimizeForTarget>();
97 }
98