1 //===- TensorToSPIRV.cpp - Tensor to SPIR-V Patterns ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Tensor dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/TensorToSPIRV/TensorToSPIRV.h"
14 #include "../SPIRVCommon/Pattern.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
19 #include "mlir/Dialect/Tensor/IR/Tensor.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/Support/LogicalResult.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/Support/Debug.h"
24 
25 #define DEBUG_TYPE "tensor-to-spirv-pattern"
26 
27 using namespace mlir;
28 
29 //===----------------------------------------------------------------------===//
30 // Operation conversion
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
34 
35 /// Converts tensor.extract into loading using access chains from SPIR-V local
36 /// variables.
37 class TensorExtractPattern final
38     : public OpConversionPattern<tensor::ExtractOp> {
39 public:
TensorExtractPattern(TypeConverter & typeConverter,MLIRContext * context,int64_t threshold,PatternBenefit benefit=1)40   TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context,
41                        int64_t threshold, PatternBenefit benefit = 1)
42       : OpConversionPattern(typeConverter, context, benefit),
43         byteCountThreshold(threshold) {}
44 
45   LogicalResult
matchAndRewrite(tensor::ExtractOp extractOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const46   matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor,
47                   ConversionPatternRewriter &rewriter) const override {
48     TensorType tensorType = extractOp.getTensor().getType().cast<TensorType>();
49 
50     if (!tensorType.hasStaticShape())
51       return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
52 
53     if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
54         byteCountThreshold * 8)
55       return rewriter.notifyMatchFailure(extractOp,
56                                          "exceeding byte count threshold");
57 
58     Location loc = extractOp.getLoc();
59 
60     int64_t rank = tensorType.getRank();
61     SmallVector<int64_t, 4> strides(rank, 1);
62     for (int i = rank - 2; i >= 0; --i) {
63       strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
64     }
65 
66     Type varType = spirv::PointerType::get(adaptor.getTensor().getType(),
67                                            spirv::StorageClass::Function);
68 
69     spirv::VariableOp varOp;
70     if (adaptor.getTensor().getDefiningOp<spirv::ConstantOp>()) {
71       varOp = rewriter.create<spirv::VariableOp>(
72           loc, varType, spirv::StorageClass::Function,
73           /*initializer=*/adaptor.getTensor());
74     } else {
75       // Need to store the value to the local variable. It's questionable
76       // whether we want to support such case though.
77       return failure();
78     }
79 
80     auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
81     auto indexType = typeConverter.getIndexType();
82 
83     Value index = spirv::linearizeIndex(adaptor.getIndices(), strides,
84                                         /*offset=*/0, indexType, loc, rewriter);
85     auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
86 
87     rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
88 
89     return success();
90   }
91 
92 private:
93   int64_t byteCountThreshold;
94 };
95 
96 } // namespace
97 
98 //===----------------------------------------------------------------------===//
99 // Pattern population
100 //===----------------------------------------------------------------------===//
101 
populateTensorToSPIRVPatterns(SPIRVTypeConverter & typeConverter,int64_t byteCountThreshold,RewritePatternSet & patterns)102 void mlir::populateTensorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
103                                          int64_t byteCountThreshold,
104                                          RewritePatternSet &patterns) {
105   patterns.add<TensorExtractPattern>(typeConverter, patterns.getContext(),
106                                      byteCountThreshold);
107 }
108