1 //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LLVMCommon/VectorPattern.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11 
12 using namespace mlir;
13 
14 // For >1-D vector types, extracts the necessary information to iterate over all
15 // 1-D subvectors in the underlying llrepresentation of the n-D vector
16 // Iterates on the llvm array type until we hit a non-array type (which is
17 // asserted to be an llvm vector type).
18 LLVM::detail::NDVectorTypeInfo
extractNDVectorTypeInfo(VectorType vectorType,LLVMTypeConverter & converter)19 LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType,
20                                       LLVMTypeConverter &converter) {
21   assert(vectorType.getRank() > 1 && "expected >1D vector type");
22   NDVectorTypeInfo info;
23   info.llvmNDVectorTy = converter.convertType(vectorType);
24   if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) {
25     info.llvmNDVectorTy = nullptr;
26     return info;
27   }
28   info.arraySizes.reserve(vectorType.getRank() - 1);
29   auto llvmTy = info.llvmNDVectorTy;
30   while (llvmTy.isa<LLVM::LLVMArrayType>()) {
31     info.arraySizes.push_back(
32         llvmTy.cast<LLVM::LLVMArrayType>().getNumElements());
33     llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType();
34   }
35   if (!LLVM::isCompatibleVectorType(llvmTy))
36     return info;
37   info.llvm1DVectorTy = llvmTy;
38   return info;
39 }
40 
41 // Express `linearIndex` in terms of coordinates of `basis`.
42 // Returns the empty vector when linearIndex is out of the range [0, P] where
43 // P is the product of all the basis coordinates.
44 //
45 // Prerequisites:
46 //   Basis is an array of nonnegative integers (signed type inherited from
47 //   vector shape type).
getCoordinates(ArrayRef<int64_t> basis,unsigned linearIndex)48 SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis,
49                                                      unsigned linearIndex) {
50   SmallVector<int64_t, 4> res;
51   res.reserve(basis.size());
52   for (unsigned basisElement : llvm::reverse(basis)) {
53     res.push_back(linearIndex % basisElement);
54     linearIndex = linearIndex / basisElement;
55   }
56   if (linearIndex > 0)
57     return {};
58   std::reverse(res.begin(), res.end());
59   return res;
60 }
61 
62 // Iterate of linear index, convert to coords space and insert splatted 1-D
63 // vector in each position.
nDVectorIterate(const LLVM::detail::NDVectorTypeInfo & info,OpBuilder & builder,function_ref<void (ArrayAttr)> fun)64 void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info,
65                                    OpBuilder &builder,
66                                    function_ref<void(ArrayAttr)> fun) {
67   unsigned ub = 1;
68   for (auto s : info.arraySizes)
69     ub *= s;
70   for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
71     auto coords = getCoordinates(info.arraySizes, linearIndex);
72     // Linear index is out of bounds, we are done.
73     if (coords.empty())
74       break;
75     assert(coords.size() == info.arraySizes.size());
76     auto position = builder.getI64ArrayAttr(coords);
77     fun(position);
78   }
79 }
80 
handleMultidimensionalVectors(Operation * op,ValueRange operands,LLVMTypeConverter & typeConverter,std::function<Value (Type,ValueRange)> createOperand,ConversionPatternRewriter & rewriter)81 LogicalResult LLVM::detail::handleMultidimensionalVectors(
82     Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter,
83     std::function<Value(Type, ValueRange)> createOperand,
84     ConversionPatternRewriter &rewriter) {
85   auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>();
86 
87   SmallVector<Type> operand1DVectorTypes;
88   for (Value operand : op->getOperands()) {
89     auto operandNDVectorType = operand.getType().cast<VectorType>();
90     auto operandTypeInfo =
91         extractNDVectorTypeInfo(operandNDVectorType, typeConverter);
92     operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy);
93   }
94   auto resultTypeInfo =
95       extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
96   auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
97   auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
98   auto loc = op->getLoc();
99   Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
100   nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) {
101     // For this unrolled `position` corresponding to the `linearIndex`^th
102     // element, extract operand vectors
103     SmallVector<Value, 4> extractedOperands;
104     for (const auto &operand : llvm::enumerate(operands)) {
105       extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
106           loc, operand1DVectorTypes[operand.index()], operand.value(),
107           position));
108     }
109     Value newVal = createOperand(result1DVectorTy, extractedOperands);
110     desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc,
111                                                 newVal, position);
112   });
113   rewriter.replaceOp(op, desc);
114   return success();
115 }
116 
vectorOneToOneRewrite(Operation * op,StringRef targetOp,ValueRange operands,LLVMTypeConverter & typeConverter,ConversionPatternRewriter & rewriter)117 LogicalResult LLVM::detail::vectorOneToOneRewrite(
118     Operation *op, StringRef targetOp, ValueRange operands,
119     LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
120   assert(!operands.empty());
121 
122   // Cannot convert ops if their operands are not of LLVM type.
123   if (!llvm::all_of(operands.getTypes(), isCompatibleType))
124     return failure();
125 
126   auto llvmNDVectorTy = operands[0].getType();
127   if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>())
128     return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter);
129 
130   auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy,
131                                             ValueRange operands) {
132     return rewriter
133         .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
134                 llvm1DVectorTy, op->getAttrs())
135         ->getResult(0);
136   };
137 
138   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
139                                        rewriter);
140 }
141