1 //===- TensorToSPIRVPass.cpp - Tensor to SPIR-V Passes ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert Tensor dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h" 14 #include "../PassDetail.h" 15 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h" 16 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" 17 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h" 18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 20 21 using namespace mlir; 22 23 namespace { 24 /// A pass converting MLIR Tensor operations into the SPIR-V dialect. 25 class ConvertTensorToSPIRVPass 26 : public ConvertTensorToSPIRVBase<ConvertTensorToSPIRVPass> { runOnOperation()27 void runOnOperation() override { 28 MLIRContext *context = &getContext(); 29 ModuleOp module = getOperation(); 30 31 auto targetAttr = spirv::lookupTargetEnvOrDefault(module); 32 std::unique_ptr<ConversionTarget> target = 33 SPIRVConversionTarget::get(targetAttr); 34 35 SPIRVTypeConverter::Options options; 36 options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes; 37 SPIRVTypeConverter typeConverter(targetAttr, options); 38 39 RewritePatternSet patterns(context); 40 arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns); 41 populateFuncToSPIRVPatterns(typeConverter, patterns); 42 populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64, 43 patterns); 44 populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); 45 46 if (failed(applyPartialConversion(module, *target, std::move(patterns)))) 47 return signalPassFailure(); 48 } 49 }; 50 } // namespace 51 52 std::unique_ptr<OperationPass<ModuleOp>> createConvertTensorToSPIRVPass()53mlir::createConvertTensorToSPIRVPass() { 54 return std::make_unique<ConvertTensorToSPIRVPass>(); 55 } 56