1 //===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" 10 #include "PassDetail.h" 11 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/PatternMatch.h" 14 #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 15 16 using namespace mlir; 17 18 namespace { 19 // Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one 20 // (conditional) Newton iteration. 21 // 22 // This as accurate as promoting the division to fp32 in the NVPTX backend, but 23 // faster because it performs less Newton iterations, avoids the slow path 24 // for e.g. denormals, and allows reuse of the reciprocal for multiple divisions 25 // by the same divisor. 26 struct ExpandDivF16 : public OpRewritePattern<LLVM::FDivOp> { 27 using OpRewritePattern<LLVM::FDivOp>::OpRewritePattern; 28 29 private: 30 LogicalResult matchAndRewrite(LLVM::FDivOp op, 31 PatternRewriter &rewriter) const override; 32 }; 33 34 struct NVVMOptimizeForTarget 35 : public NVVMOptimizeForTargetBase<NVVMOptimizeForTarget> { 36 void runOnOperation() override; 37 38 void getDependentDialects(DialectRegistry ®istry) const override { 39 registry.insert<NVVM::NVVMDialect>(); 40 } 41 }; 42 } // namespace 43 44 LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op, 45 PatternRewriter &rewriter) const { 46 if (!op.getType().isF16()) 47 return rewriter.notifyMatchFailure(op, "not f16"); 48 Location loc = op.getLoc(); 49 50 Type f32Type = rewriter.getF32Type(); 51 Type i32Type = rewriter.getI32Type(); 52 53 // Extend lhs and rhs to fp32. 54 Value lhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getLhs()); 55 Value rhs = rewriter.create<LLVM::FPExtOp>(loc, f32Type, op.getRhs()); 56 57 // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp. 58 Value rcp = rewriter.create<NVVM::RcpApproxFtzF32Op>(loc, f32Type, rhs); 59 Value approx = rewriter.create<LLVM::FMulOp>(loc, lhs, rcp); 60 61 // Refine the approximation with one Newton iteration: 62 // float refined = approx + (lhs - approx * rhs) * rcp; 63 Value err = rewriter.create<LLVM::FMAOp>( 64 loc, approx, rewriter.create<LLVM::FNegOp>(loc, rhs), lhs); 65 Value refined = rewriter.create<LLVM::FMAOp>(loc, err, rcp, approx); 66 67 // Use refined value if approx is normal (exponent neither all 0 or all 1). 68 Value mask = rewriter.create<LLVM::ConstantOp>( 69 loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000)); 70 Value cast = rewriter.create<LLVM::BitcastOp>(loc, i32Type, approx); 71 Value exp = rewriter.create<LLVM::AndOp>(loc, i32Type, cast, mask); 72 Value zero = rewriter.create<LLVM::ConstantOp>( 73 loc, i32Type, rewriter.getUI32IntegerAttr(0)); 74 Value pred = rewriter.create<LLVM::OrOp>( 75 loc, 76 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, zero), 77 rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, exp, mask)); 78 Value result = 79 rewriter.create<LLVM::SelectOp>(loc, f32Type, pred, approx, refined); 80 81 // Replace with trucation back to fp16. 82 rewriter.replaceOpWithNewOp<LLVM::FPTruncOp>(op, op.getType(), result); 83 84 return success(); 85 } 86 87 void NVVMOptimizeForTarget::runOnOperation() { 88 MLIRContext *ctx = getOperation()->getContext(); 89 RewritePatternSet patterns(ctx); 90 patterns.add<ExpandDivF16>(ctx); 91 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) 92 return signalPassFailure(); 93 } 94 95 std::unique_ptr<Pass> NVVM::createOptimizeForTargetPass() { 96 return std::make_unique<NVVMOptimizeForTarget>(); 97 } 98