1*3ba66435SRiver Riddle //===- TensorToSPIRVPass.cpp - Tensor to SPIR-V Passes ----------------===// 2*3ba66435SRiver Riddle // 3*3ba66435SRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*3ba66435SRiver Riddle // See https://llvm.org/LICENSE.txt for license information. 5*3ba66435SRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*3ba66435SRiver Riddle // 7*3ba66435SRiver Riddle //===----------------------------------------------------------------------===// 8*3ba66435SRiver Riddle // 9*3ba66435SRiver Riddle // This file implements a pass to convert Tensor dialect to SPIR-V dialect. 10*3ba66435SRiver Riddle // 11*3ba66435SRiver Riddle //===----------------------------------------------------------------------===// 12*3ba66435SRiver Riddle 13*3ba66435SRiver Riddle #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h" 14*3ba66435SRiver Riddle #include "../PassDetail.h" 15*3ba66435SRiver Riddle #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" 16*3ba66435SRiver Riddle #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" 17*3ba66435SRiver Riddle #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h" 18*3ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 19*3ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20*3ba66435SRiver Riddle 21*3ba66435SRiver Riddle using namespace mlir; 22*3ba66435SRiver Riddle 23*3ba66435SRiver Riddle namespace { 24*3ba66435SRiver Riddle /// A pass converting MLIR Tensor operations into the SPIR-V dialect. 25*3ba66435SRiver Riddle class ConvertTensorToSPIRVPass 26*3ba66435SRiver Riddle : public ConvertTensorToSPIRVBase<ConvertTensorToSPIRVPass> { runOnOperation()27*3ba66435SRiver Riddle void runOnOperation() override { 28*3ba66435SRiver Riddle MLIRContext *context = &getContext(); 29*3ba66435SRiver Riddle ModuleOp module = getOperation(); 30*3ba66435SRiver Riddle 31*3ba66435SRiver Riddle auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 32*3ba66435SRiver Riddle std::unique_ptr<ConversionTarget> target = 33*3ba66435SRiver Riddle SPIRVConversionTarget::get(targetAttr); 34*3ba66435SRiver Riddle 35*3ba66435SRiver Riddle SPIRVTypeConverter::Options options; 36*3ba66435SRiver Riddle options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; 37*3ba66435SRiver Riddle SPIRVTypeConverter typeConverter(targetAttr, options); 38*3ba66435SRiver Riddle 39*3ba66435SRiver Riddle RewritePatternSet patterns(context); 40*3ba66435SRiver Riddle arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); 41*3ba66435SRiver Riddle populateFuncToSPIRVPatterns(typeConverter, patterns); 42*3ba66435SRiver Riddle populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, 43*3ba66435SRiver Riddle patterns); 44*3ba66435SRiver Riddle populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); 45*3ba66435SRiver Riddle 46*3ba66435SRiver Riddle if (failed(applyPartialConversion(module, *target, std::move(patterns)))) 47*3ba66435SRiver Riddle return signalPassFailure(); 48*3ba66435SRiver Riddle } 49*3ba66435SRiver Riddle }; 50*3ba66435SRiver Riddle } // namespace 51*3ba66435SRiver Riddle 52*3ba66435SRiver Riddle std::unique_ptr<OperationPass<ModuleOp>> createConvertTensorToSPIRVPass()53*3ba66435SRiver Riddlemlir::createConvertTensorToSPIRVPass() { 54*3ba66435SRiver Riddle return std::make_unique<ConvertTensorToSPIRVPass>(); 55*3ba66435SRiver Riddle } 56