1 //===- TensorToSPIRVPass.cpp - Tensor to SPIR-V Passes ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert Tensor dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
14 #include "../PassDetail.h"
15 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
16 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
17 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20 
21 using namespace mlir;
22 
23 namespace {
24 /// A pass converting MLIR Tensor operations into the SPIR-V dialect.
25 class ConvertTensorToSPIRVPass
26     : public ConvertTensorToSPIRVBase<ConvertTensorToSPIRVPass> {
runOnOperation()27   void runOnOperation() override {
28     MLIRContext *context = &getContext();
29     ModuleOp module = getOperation();
30 
31     auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
32     std::unique_ptr<ConversionTarget> target =
33         SPIRVConversionTarget::get(targetAttr);
34 
35     SPIRVTypeConverter::Options options;
36     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
37     SPIRVTypeConverter typeConverter(targetAttr, options);
38 
39     RewritePatternSet patterns(context);
40     arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
41     populateFuncToSPIRVPatterns(typeConverter, patterns);
42     populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64,
43                                   patterns);
44     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
45 
46     if (failed(applyPartialConversion(module, *target, std::move(patterns))))
47       return signalPassFailure();
48   }
49 };
50 } // namespace
51 
52 std::unique_ptr<OperationPass<ModuleOp>>
createConvertTensorToSPIRVPass()53 mlir::createConvertTensorToSPIRVPass() {
54   return std::make_unique<ConvertTensorToSPIRVPass>();
55 }
56