1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 10 11 #include "mlir/Conversion/LLVMCommon/Pattern.h" 12 #include "mlir/Dialect/GPU/GPUDialect.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/Dialect/StandardOps/IR/Ops.h" 15 #include "mlir/IR/Builders.h" 16 17 namespace mlir { 18 19 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` 20 /// depending on the element type that Op operates upon. The function 21 /// declaration is added in case it was not added before. 22 /// 23 /// If the input values are of f16 type, the value is first casted to f32, the 24 /// function called and then the result casted back. 25 /// 26 /// Example with NVVM: 27 /// %exp_f32 = std.exp %arg_f32 : f32 28 /// 29 /// will be transformed into 30 /// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32 31 template <typename SourceOp> 32 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> { 33 public: 34 explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func, 35 StringRef f64Func) 36 : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func), 37 f64Func(f64Func) {} 38 39 LogicalResult 40 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 41 ConversionPatternRewriter &rewriter) const override { 42 using LLVM::LLVMFuncOp; 43 44 static_assert( 45 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, 46 "expected single result op"); 47 48 static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, 49 SourceOp>::value, 50 "expected op with same operand and result types"); 51 52 SmallVector<Value, 1> castedOperands; 53 for (Value operand : adaptor.getOperands()) 54 castedOperands.push_back(maybeCast(operand, rewriter)); 55 56 Type resultType = castedOperands.front().getType(); 57 Type funcType = getFunctionType(resultType, castedOperands); 58 StringRef funcName = getFunctionName( 59 funcType.cast<LLVM::LLVMFunctionType>().getReturnType()); 60 if (funcName.empty()) 61 return failure(); 62 63 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); 64 auto callOp = rewriter.create<LLVM::CallOp>( 65 op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands); 66 67 if (resultType == adaptor.getOperands().front().getType()) { 68 rewriter.replaceOp(op, {callOp.getResult(0)}); 69 return success(); 70 } 71 72 Value truncated = rewriter.create<LLVM::FPTruncOp>( 73 op->getLoc(), adaptor.getOperands().front().getType(), 74 callOp.getResult(0)); 75 rewriter.replaceOp(op, {truncated}); 76 return success(); 77 } 78 79 private: 80 Value maybeCast(Value operand, PatternRewriter &rewriter) const { 81 Type type = operand.getType(); 82 if (!type.isa<Float16Type>()) 83 return operand; 84 85 return rewriter.create<LLVM::FPExtOp>( 86 operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); 87 } 88 89 Type getFunctionType(Type resultType, ValueRange operands) const { 90 SmallVector<Type> operandTypes(operands.getTypes()); 91 return LLVM::LLVMFunctionType::get(resultType, operandTypes); 92 } 93 94 StringRef getFunctionName(Type type) const { 95 if (type.isa<Float32Type>()) 96 return f32Func; 97 if (type.isa<Float64Type>()) 98 return f64Func; 99 return ""; 100 } 101 102 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, 103 Operation *op) const { 104 using LLVM::LLVMFuncOp; 105 106 auto funcAttr = StringAttr::get(op->getContext(), funcName); 107 Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); 108 if (funcOp) 109 return cast<LLVMFuncOp>(*funcOp); 110 111 mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>()); 112 return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType); 113 } 114 115 const std::string f32Func; 116 const std::string f64Func; 117 }; 118 119 } // namespace mlir 120 121 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 122