//===- TensorToSPIRVPass.cpp - Tensor to SPIR-V Passes ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the Tensor dialect to the SPIR-V
// dialect.
//
//===----------------------------------------------------------------------===//
12*3ba66435SRiver Riddle 
13*3ba66435SRiver Riddle #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
14*3ba66435SRiver Riddle #include "../PassDetail.h"
15*3ba66435SRiver Riddle #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
16*3ba66435SRiver Riddle #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
17*3ba66435SRiver Riddle #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h"
18*3ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
19*3ba66435SRiver Riddle #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
20*3ba66435SRiver Riddle 
21*3ba66435SRiver Riddle using namespace mlir;
22*3ba66435SRiver Riddle 
23*3ba66435SRiver Riddle namespace {
24*3ba66435SRiver Riddle /// A pass converting MLIR Tensor operations into the SPIR-V dialect.
25*3ba66435SRiver Riddle class ConvertTensorToSPIRVPass
26*3ba66435SRiver Riddle     : public ConvertTensorToSPIRVBase<ConvertTensorToSPIRVPass> {
runOnOperation()27*3ba66435SRiver Riddle   void runOnOperation() override {
28*3ba66435SRiver Riddle     MLIRContext *context = &getContext();
29*3ba66435SRiver Riddle     ModuleOp module = getOperation();
30*3ba66435SRiver Riddle 
31*3ba66435SRiver Riddle     auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
32*3ba66435SRiver Riddle     std::unique_ptr<ConversionTarget> target =
33*3ba66435SRiver Riddle         SPIRVConversionTarget::get(targetAttr);
34*3ba66435SRiver Riddle 
35*3ba66435SRiver Riddle     SPIRVTypeConverter::Options options;
36*3ba66435SRiver Riddle     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
37*3ba66435SRiver Riddle     SPIRVTypeConverter typeConverter(targetAttr, options);
38*3ba66435SRiver Riddle 
39*3ba66435SRiver Riddle     RewritePatternSet patterns(context);
40*3ba66435SRiver Riddle     arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
41*3ba66435SRiver Riddle     populateFuncToSPIRVPatterns(typeConverter, patterns);
42*3ba66435SRiver Riddle     populateTensorToSPIRVPatterns(typeConverter, /*byteCountThreshold=*/64,
43*3ba66435SRiver Riddle                                   patterns);
44*3ba66435SRiver Riddle     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
45*3ba66435SRiver Riddle 
46*3ba66435SRiver Riddle     if (failed(applyPartialConversion(module, *target, std::move(patterns))))
47*3ba66435SRiver Riddle       return signalPassFailure();
48*3ba66435SRiver Riddle   }
49*3ba66435SRiver Riddle };
50*3ba66435SRiver Riddle } // namespace
51*3ba66435SRiver Riddle 
52*3ba66435SRiver Riddle std::unique_ptr<OperationPass<ModuleOp>>
createConvertTensorToSPIRVPass()53*3ba66435SRiver Riddle mlir::createConvertTensorToSPIRVPass() {
54*3ba66435SRiver Riddle   return std::make_unique<ConvertTensorToSPIRVPass>();
55*3ba66435SRiver Riddle }
56