1 //===- OpToFuncCallLowering.h - GPU ops lowering to custom calls *- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 #ifndef MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 9 #define MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 10 11 #include "mlir/Conversion/LLVMCommon/Pattern.h" 12 #include "mlir/Dialect/GPU/IR/GPUDialect.h" 13 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 14 #include "mlir/IR/Builders.h" 15 16 namespace mlir { 17 18 /// Rewriting that replace SourceOp with a CallOp to `f32Func` or `f64Func` 19 /// depending on the element type that Op operates upon. The function 20 /// declaration is added in case it was not added before. 21 /// 22 /// If the input values are of f16 type, the value is first casted to f32, the 23 /// function called and then the result casted back. 24 /// 25 /// Example with NVVM: 26 /// %exp_f32 = math.exp %arg_f32 : f32 27 /// 28 /// will be transformed into 29 /// llvm.call @__nv_expf(%arg_f32) : (f32) -> f32 30 template <typename SourceOp> 31 struct OpToFuncCallLowering : public ConvertOpToLLVMPattern<SourceOp> { 32 public: OpToFuncCallLoweringOpToFuncCallLowering33 explicit OpToFuncCallLowering(LLVMTypeConverter &lowering, StringRef f32Func, 34 StringRef f64Func) 35 : ConvertOpToLLVMPattern<SourceOp>(lowering), f32Func(f32Func), 36 f64Func(f64Func) {} 37 38 LogicalResult matchAndRewriteOpToFuncCallLowering39 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, 40 ConversionPatternRewriter &rewriter) const override { 41 using LLVM::LLVMFuncOp; 42 43 static_assert( 44 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value, 45 "expected single result op"); 46 47 static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>, 48 SourceOp>::value, 49 "expected op with same operand and result types"); 50 51 SmallVector<Value, 1> castedOperands; 52 for (Value operand : adaptor.getOperands()) 53 castedOperands.push_back(maybeCast(operand, rewriter)); 54 55 Type resultType = castedOperands.front().getType(); 56 Type funcType = getFunctionType(resultType, castedOperands); 57 StringRef funcName = getFunctionName( 58 funcType.cast<LLVM::LLVMFunctionType>().getReturnType()); 59 if (funcName.empty()) 60 return failure(); 61 62 LLVMFuncOp funcOp = appendOrGetFuncOp(funcName, funcType, op); 63 auto callOp = rewriter.create<LLVM::CallOp>( 64 op->getLoc(), resultType, SymbolRefAttr::get(funcOp), castedOperands); 65 66 if (resultType == adaptor.getOperands().front().getType()) { 67 rewriter.replaceOp(op, {callOp.getResult(0)}); 68 return success(); 69 } 70 71 Value truncated = rewriter.create<LLVM::FPTruncOp>( 72 op->getLoc(), adaptor.getOperands().front().getType(), 73 callOp.getResult(0)); 74 rewriter.replaceOp(op, {truncated}); 75 return success(); 76 } 77 78 private: maybeCastOpToFuncCallLowering79 Value maybeCast(Value operand, PatternRewriter &rewriter) const { 80 Type type = operand.getType(); 81 if (!type.isa<Float16Type>()) 82 return operand; 83 84 return rewriter.create<LLVM::FPExtOp>( 85 operand.getLoc(), Float32Type::get(rewriter.getContext()), operand); 86 } 87 getFunctionTypeOpToFuncCallLowering88 Type getFunctionType(Type resultType, ValueRange operands) const { 89 SmallVector<Type> operandTypes(operands.getTypes()); 90 return LLVM::LLVMFunctionType::get(resultType, operandTypes); 91 } 92 getFunctionNameOpToFuncCallLowering93 StringRef getFunctionName(Type type) const { 94 if (type.isa<Float32Type>()) 95 return f32Func; 96 if (type.isa<Float64Type>()) 97 return f64Func; 98 return ""; 99 } 100 appendOrGetFuncOpOpToFuncCallLowering101 LLVM::LLVMFuncOp appendOrGetFuncOp(StringRef funcName, Type funcType, 102 Operation *op) const { 103 using LLVM::LLVMFuncOp; 104 105 auto funcAttr = StringAttr::get(op->getContext(), funcName); 106 Operation *funcOp = SymbolTable::lookupNearestSymbolFrom(op, funcAttr); 107 if (funcOp) 108 return cast<LLVMFuncOp>(*funcOp); 109 110 mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>()); 111 return b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType); 112 } 113 114 const std::string f32Func; 115 const std::string f64Func; 116 }; 117 118 } // namespace mlir 119 120 #endif // MLIR_CONVERSION_GPUCOMMON_OPTOFUNCCALLLOWERING_H_ 121