1684dfe8aSAlex Zinenko //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
2684dfe8aSAlex Zinenko //
3684dfe8aSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4684dfe8aSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5684dfe8aSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6684dfe8aSAlex Zinenko //
7684dfe8aSAlex Zinenko //===----------------------------------------------------------------------===//
8684dfe8aSAlex Zinenko
9684dfe8aSAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
10684dfe8aSAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11684dfe8aSAlex Zinenko
12684dfe8aSAlex Zinenko using namespace mlir;
13684dfe8aSAlex Zinenko
14684dfe8aSAlex Zinenko // For >1-D vector types, extracts the necessary information to iterate over all
15684dfe8aSAlex Zinenko // 1-D subvectors in the underlying llrepresentation of the n-D vector
16684dfe8aSAlex Zinenko // Iterates on the llvm array type until we hit a non-array type (which is
17684dfe8aSAlex Zinenko // asserted to be an llvm vector type).
18684dfe8aSAlex Zinenko LLVM::detail::NDVectorTypeInfo
extractNDVectorTypeInfo(VectorType vectorType,LLVMTypeConverter & converter)19684dfe8aSAlex Zinenko LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType,
20684dfe8aSAlex Zinenko LLVMTypeConverter &converter) {
21684dfe8aSAlex Zinenko assert(vectorType.getRank() > 1 && "expected >1D vector type");
22684dfe8aSAlex Zinenko NDVectorTypeInfo info;
23684dfe8aSAlex Zinenko info.llvmNDVectorTy = converter.convertType(vectorType);
24684dfe8aSAlex Zinenko if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
25684dfe8aSAlex Zinenko info.llvmNDVectorTy = nullptr;
26684dfe8aSAlex Zinenko return info;
27684dfe8aSAlex Zinenko }
28684dfe8aSAlex Zinenko info.arraySizes.reserve(vectorType.getRank() - 1);
29684dfe8aSAlex Zinenko auto llvmTy = info.llvmNDVectorTy;
30684dfe8aSAlex Zinenko while (llvmTy.isa<LLVM::LLVMArrayType>()) {
31684dfe8aSAlex Zinenko info.arraySizes.push_back(
32684dfe8aSAlex Zinenko llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
33684dfe8aSAlex Zinenko llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
34684dfe8aSAlex Zinenko }
35684dfe8aSAlex Zinenko if (!LLVM::isCompatibleVectorType(llvmTy))
36684dfe8aSAlex Zinenko return info;
37684dfe8aSAlex Zinenko info.llvm1DVectorTy = llvmTy;
38684dfe8aSAlex Zinenko return info;
39684dfe8aSAlex Zinenko }
40684dfe8aSAlex Zinenko
41684dfe8aSAlex Zinenko // Express `linearIndex` in terms of coordinates of `basis`.
42684dfe8aSAlex Zinenko // Returns the empty vector when linearIndex is out of the range [0, P] where
43684dfe8aSAlex Zinenko // P is the product of all the basis coordinates.
44684dfe8aSAlex Zinenko //
45684dfe8aSAlex Zinenko // Prerequisites:
46684dfe8aSAlex Zinenko // Basis is an array of nonnegative integers (signed type inherited from
47684dfe8aSAlex Zinenko // vector shape type).
getCoordinates(ArrayRef<int64_t> basis,unsigned linearIndex)48684dfe8aSAlex Zinenko SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis,
49684dfe8aSAlex Zinenko unsigned linearIndex) {
50684dfe8aSAlex Zinenko SmallVector<int64_t, 4> res;
51684dfe8aSAlex Zinenko res.reserve(basis.size());
52684dfe8aSAlex Zinenko for (unsigned basisElement : llvm::reverse(basis)) {
53684dfe8aSAlex Zinenko res.push_back(linearIndex % basisElement);
54684dfe8aSAlex Zinenko linearIndex = linearIndex / basisElement;
55684dfe8aSAlex Zinenko }
56684dfe8aSAlex Zinenko if (linearIndex > 0)
57684dfe8aSAlex Zinenko return {};
58684dfe8aSAlex Zinenko std::reverse(res.begin(), res.end());
59684dfe8aSAlex Zinenko return res;
60684dfe8aSAlex Zinenko }
61684dfe8aSAlex Zinenko
62684dfe8aSAlex Zinenko // Iterate of linear index, convert to coords space and insert splatted 1-D
63684dfe8aSAlex Zinenko // vector in each position.
nDVectorIterate(const LLVM::detail::NDVectorTypeInfo & info,OpBuilder & builder,function_ref<void (ArrayAttr)> fun)64684dfe8aSAlex Zinenko void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
65684dfe8aSAlex Zinenko OpBuilder &builder,
66684dfe8aSAlex Zinenko function_ref<void(ArrayAttr)> fun) {
67684dfe8aSAlex Zinenko unsigned ub = 1;
68684dfe8aSAlex Zinenko for (auto s : info.arraySizes)
69684dfe8aSAlex Zinenko ub *= s;
70684dfe8aSAlex Zinenko for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
71684dfe8aSAlex Zinenko auto coords = getCoordinates(info.arraySizes, linearIndex);
72684dfe8aSAlex Zinenko // Linear index is out of bounds, we are done.
73684dfe8aSAlex Zinenko if (coords.empty())
74684dfe8aSAlex Zinenko break;
75684dfe8aSAlex Zinenko assert(coords.size() == info.arraySizes.size());
76684dfe8aSAlex Zinenko auto position = builder.getI64ArrayAttr(coords);
77684dfe8aSAlex Zinenko fun(position);
78684dfe8aSAlex Zinenko }
79684dfe8aSAlex Zinenko }
80684dfe8aSAlex Zinenko
handleMultidimensionalVectors(Operation * op,ValueRange operands,LLVMTypeConverter & typeConverter,std::function<Value (Type,ValueRange)> createOperand,ConversionPatternRewriter & rewriter)81684dfe8aSAlex Zinenko LogicalResult LLVM::detail::handleMultidimensionalVectors(
82684dfe8aSAlex Zinenko Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
83684dfe8aSAlex Zinenko std::function<Value(Type, ValueRange)> createOperand,
84684dfe8aSAlex Zinenko ConversionPatternRewriter &rewriter) {
85684dfe8aSAlex Zinenko auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
86684dfe8aSAlex Zinenko
87684dfe8aSAlex Zinenko SmallVector<Type> operand1DVectorTypes;
88684dfe8aSAlex Zinenko for (Value operand : op->getOperands()) {
89684dfe8aSAlex Zinenko auto operandNDVectorType = operand.getType().cast<VectorType>();
90684dfe8aSAlex Zinenko auto operandTypeInfo =
91684dfe8aSAlex Zinenko extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
92684dfe8aSAlex Zinenko operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
93684dfe8aSAlex Zinenko }
94684dfe8aSAlex Zinenko auto resultTypeInfo =
95684dfe8aSAlex Zinenko extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
96684dfe8aSAlex Zinenko auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
97684dfe8aSAlex Zinenko auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
98684dfe8aSAlex Zinenko auto loc = op->getLoc();
99684dfe8aSAlex Zinenko Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
100684dfe8aSAlex Zinenko nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
101684dfe8aSAlex Zinenko // For this unrolled `position` corresponding to the `linearIndex`^th
102684dfe8aSAlex Zinenko // element, extract operand vectors
103684dfe8aSAlex Zinenko SmallVector<Value, 4> extractedOperands;
104e4853be2SMehdi Amini for (const auto &operand : llvm::enumerate(operands)) {
105684dfe8aSAlex Zinenko extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
106684dfe8aSAlex Zinenko loc, operand1DVectorTypes[operand.index()], operand.value(),
107684dfe8aSAlex Zinenko position));
108684dfe8aSAlex Zinenko }
109684dfe8aSAlex Zinenko Value newVal = createOperand(result1DVectorTy, extractedOperands);
110684dfe8aSAlex Zinenko desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
111684dfe8aSAlex Zinenko newVal, position);
112684dfe8aSAlex Zinenko });
113684dfe8aSAlex Zinenko rewriter.replaceOp(op, desc);
114684dfe8aSAlex Zinenko return success();
115684dfe8aSAlex Zinenko }
116684dfe8aSAlex Zinenko
vectorOneToOneRewrite(Operation * op,StringRef targetOp,ValueRange operands,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)117684dfe8aSAlex Zinenko LogicalResult LLVM::detail::vectorOneToOneRewrite(
118684dfe8aSAlex Zinenko Operation *op, StringRef targetOp, ValueRange operands,
119684dfe8aSAlex Zinenko LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
120684dfe8aSAlex Zinenko assert(!operands.empty());
121684dfe8aSAlex Zinenko
122684dfe8aSAlex Zinenko // Cannot convert ops if their operands are not of LLVM type.
123*380a1b20SKazu Hirata if (!llvm::all_of(operands.getTypes(), isCompatibleType))
124684dfe8aSAlex Zinenko return failure();
125684dfe8aSAlex Zinenko
126684dfe8aSAlex Zinenko auto llvmNDVectorTy = operands[0].getType();
127684dfe8aSAlex Zinenko if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
128684dfe8aSAlex Zinenko return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
129684dfe8aSAlex Zinenko
130684dfe8aSAlex Zinenko auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
131684dfe8aSAlex Zinenko ValueRange operands) {
13214ecafd0SChia-hung Duan return rewriter
13314ecafd0SChia-hung Duan .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
13414ecafd0SChia-hung Duan llvm1DVectorTy, op->getAttrs())
13514ecafd0SChia-hung Duan ->getResult(0);
136684dfe8aSAlex Zinenko };
137684dfe8aSAlex Zinenko
138684dfe8aSAlex Zinenko return handleMultidimensionalVectors(op, operands, typeConverter, callback,
139684dfe8aSAlex Zinenko rewriter);
140684dfe8aSAlex Zinenko }
141