1*684dfe8aSAlex Zinenko //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
2*684dfe8aSAlex Zinenko //
3*684dfe8aSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*684dfe8aSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information.
5*684dfe8aSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*684dfe8aSAlex Zinenko //
7*684dfe8aSAlex Zinenko //===----------------------------------------------------------------------===//
8*684dfe8aSAlex Zinenko 
9*684dfe8aSAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
10*684dfe8aSAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11*684dfe8aSAlex Zinenko 
12*684dfe8aSAlex Zinenko using namespace mlir;
13*684dfe8aSAlex Zinenko 
14*684dfe8aSAlex Zinenko // For >1-D vector types, extracts the necessary information to iterate over all
15*684dfe8aSAlex Zinenko // 1-D subvectors in the underlying llrepresentation of the n-D vector
16*684dfe8aSAlex Zinenko // Iterates on the llvm array type until we hit a non-array type (which is
17*684dfe8aSAlex Zinenko // asserted to be an llvm vector type).
18*684dfe8aSAlex Zinenko LLVM::detail::NDVectorTypeInfo
19*684dfe8aSAlex Zinenko LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType,
20*684dfe8aSAlex Zinenko                                       LLVMTypeConverter &converter) {
21*684dfe8aSAlex Zinenko   assert(vectorType.getRank() > 1 && "expected >1D vector type");
22*684dfe8aSAlex Zinenko   NDVectorTypeInfo info;
23*684dfe8aSAlex Zinenko   info.llvmNDVectorTy = converter.convertType(vectorType);
24*684dfe8aSAlex Zinenko   if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
25*684dfe8aSAlex Zinenko     info.llvmNDVectorTy = nullptr;
26*684dfe8aSAlex Zinenko     return info;
27*684dfe8aSAlex Zinenko   }
28*684dfe8aSAlex Zinenko   info.arraySizes.reserve(vectorType.getRank() - 1);
29*684dfe8aSAlex Zinenko   auto llvmTy = info.llvmNDVectorTy;
30*684dfe8aSAlex Zinenko   while (llvmTy.isa<LLVM::LLVMArrayType>()) {
31*684dfe8aSAlex Zinenko     info.arraySizes.push_back(
32*684dfe8aSAlex Zinenko         llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
33*684dfe8aSAlex Zinenko     llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
34*684dfe8aSAlex Zinenko   }
35*684dfe8aSAlex Zinenko   if (!LLVM::isCompatibleVectorType(llvmTy))
36*684dfe8aSAlex Zinenko     return info;
37*684dfe8aSAlex Zinenko   info.llvm1DVectorTy = llvmTy;
38*684dfe8aSAlex Zinenko   return info;
39*684dfe8aSAlex Zinenko }
40*684dfe8aSAlex Zinenko 
41*684dfe8aSAlex Zinenko // Express `linearIndex` in terms of coordinates of `basis`.
42*684dfe8aSAlex Zinenko // Returns the empty vector when linearIndex is out of the range [0, P] where
43*684dfe8aSAlex Zinenko // P is the product of all the basis coordinates.
44*684dfe8aSAlex Zinenko //
45*684dfe8aSAlex Zinenko // Prerequisites:
46*684dfe8aSAlex Zinenko //   Basis is an array of nonnegative integers (signed type inherited from
47*684dfe8aSAlex Zinenko //   vector shape type).
48*684dfe8aSAlex Zinenko SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis,
49*684dfe8aSAlex Zinenko                                                      unsigned linearIndex) {
50*684dfe8aSAlex Zinenko   SmallVector<int64_t, 4> res;
51*684dfe8aSAlex Zinenko   res.reserve(basis.size());
52*684dfe8aSAlex Zinenko   for (unsigned basisElement : llvm::reverse(basis)) {
53*684dfe8aSAlex Zinenko     res.push_back(linearIndex % basisElement);
54*684dfe8aSAlex Zinenko     linearIndex = linearIndex / basisElement;
55*684dfe8aSAlex Zinenko   }
56*684dfe8aSAlex Zinenko   if (linearIndex > 0)
57*684dfe8aSAlex Zinenko     return {};
58*684dfe8aSAlex Zinenko   std::reverse(res.begin(), res.end());
59*684dfe8aSAlex Zinenko   return res;
60*684dfe8aSAlex Zinenko }
61*684dfe8aSAlex Zinenko 
62*684dfe8aSAlex Zinenko // Iterate of linear index, convert to coords space and insert splatted 1-D
63*684dfe8aSAlex Zinenko // vector in each position.
64*684dfe8aSAlex Zinenko void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
65*684dfe8aSAlex Zinenko                                    OpBuilder &builder,
66*684dfe8aSAlex Zinenko                                    function_ref<void(ArrayAttr)> fun) {
67*684dfe8aSAlex Zinenko   unsigned ub = 1;
68*684dfe8aSAlex Zinenko   for (auto s : info.arraySizes)
69*684dfe8aSAlex Zinenko     ub *= s;
70*684dfe8aSAlex Zinenko   for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
71*684dfe8aSAlex Zinenko     auto coords = getCoordinates(info.arraySizes, linearIndex);
72*684dfe8aSAlex Zinenko     // Linear index is out of bounds, we are done.
73*684dfe8aSAlex Zinenko     if (coords.empty())
74*684dfe8aSAlex Zinenko       break;
75*684dfe8aSAlex Zinenko     assert(coords.size() == info.arraySizes.size());
76*684dfe8aSAlex Zinenko     auto position = builder.getI64ArrayAttr(coords);
77*684dfe8aSAlex Zinenko     fun(position);
78*684dfe8aSAlex Zinenko   }
79*684dfe8aSAlex Zinenko }
80*684dfe8aSAlex Zinenko 
81*684dfe8aSAlex Zinenko LogicalResult LLVM::detail::handleMultidimensionalVectors(
82*684dfe8aSAlex Zinenko     Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
83*684dfe8aSAlex Zinenko     std::function<Value(Type, ValueRange)> createOperand,
84*684dfe8aSAlex Zinenko     ConversionPatternRewriter &rewriter) {
85*684dfe8aSAlex Zinenko   auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
86*684dfe8aSAlex Zinenko 
87*684dfe8aSAlex Zinenko   SmallVector<Type> operand1DVectorTypes;
88*684dfe8aSAlex Zinenko   for (Value operand : op->getOperands()) {
89*684dfe8aSAlex Zinenko     auto operandNDVectorType = operand.getType().cast<VectorType>();
90*684dfe8aSAlex Zinenko     auto operandTypeInfo =
91*684dfe8aSAlex Zinenko         extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
92*684dfe8aSAlex Zinenko     operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
93*684dfe8aSAlex Zinenko   }
94*684dfe8aSAlex Zinenko   auto resultTypeInfo =
95*684dfe8aSAlex Zinenko       extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
96*684dfe8aSAlex Zinenko   auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
97*684dfe8aSAlex Zinenko   auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
98*684dfe8aSAlex Zinenko   auto loc = op->getLoc();
99*684dfe8aSAlex Zinenko   Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
100*684dfe8aSAlex Zinenko   nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
101*684dfe8aSAlex Zinenko     // For this unrolled `position` corresponding to the `linearIndex`^th
102*684dfe8aSAlex Zinenko     // element, extract operand vectors
103*684dfe8aSAlex Zinenko     SmallVector<Value, 4> extractedOperands;
104*684dfe8aSAlex Zinenko     for (auto operand : llvm::enumerate(operands)) {
105*684dfe8aSAlex Zinenko       extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
106*684dfe8aSAlex Zinenko           loc, operand1DVectorTypes[operand.index()], operand.value(),
107*684dfe8aSAlex Zinenko           position));
108*684dfe8aSAlex Zinenko     }
109*684dfe8aSAlex Zinenko     Value newVal = createOperand(result1DVectorTy, extractedOperands);
110*684dfe8aSAlex Zinenko     desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
111*684dfe8aSAlex Zinenko                                                 newVal, position);
112*684dfe8aSAlex Zinenko   });
113*684dfe8aSAlex Zinenko   rewriter.replaceOp(op, desc);
114*684dfe8aSAlex Zinenko   return success();
115*684dfe8aSAlex Zinenko }
116*684dfe8aSAlex Zinenko 
117*684dfe8aSAlex Zinenko LogicalResult LLVM::detail::vectorOneToOneRewrite(
118*684dfe8aSAlex Zinenko     Operation *op, StringRef targetOp, ValueRange operands,
119*684dfe8aSAlex Zinenko     LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
120*684dfe8aSAlex Zinenko   assert(!operands.empty());
121*684dfe8aSAlex Zinenko 
122*684dfe8aSAlex Zinenko   // Cannot convert ops if their operands are not of LLVM type.
123*684dfe8aSAlex Zinenko   if (!llvm::all_of(operands.getTypes(),
124*684dfe8aSAlex Zinenko                     [](Type t) { return isCompatibleType(t); }))
125*684dfe8aSAlex Zinenko     return failure();
126*684dfe8aSAlex Zinenko 
127*684dfe8aSAlex Zinenko   auto llvmNDVectorTy = operands[0].getType();
128*684dfe8aSAlex Zinenko   if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
129*684dfe8aSAlex Zinenko     return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
130*684dfe8aSAlex Zinenko 
131*684dfe8aSAlex Zinenko   auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
132*684dfe8aSAlex Zinenko                                             ValueRange operands) {
133*684dfe8aSAlex Zinenko     OperationState state(op->getLoc(), targetOp);
134*684dfe8aSAlex Zinenko     state.addTypes(llvm1DVectorTy);
135*684dfe8aSAlex Zinenko     state.addOperands(operands);
136*684dfe8aSAlex Zinenko     state.addAttributes(op->getAttrs());
137*684dfe8aSAlex Zinenko     return rewriter.createOperation(state)->getResult(0);
138*684dfe8aSAlex Zinenko   };
139*684dfe8aSAlex Zinenko 
140*684dfe8aSAlex Zinenko   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
141*684dfe8aSAlex Zinenko                                        rewriter);
142*684dfe8aSAlex Zinenko }
143