1*684dfe8aSAlex Zinenko //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===// 2*684dfe8aSAlex Zinenko // 3*684dfe8aSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4*684dfe8aSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5*684dfe8aSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6*684dfe8aSAlex Zinenko // 7*684dfe8aSAlex Zinenko //===----------------------------------------------------------------------===// 8*684dfe8aSAlex Zinenko 9*684dfe8aSAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 10*684dfe8aSAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 11*684dfe8aSAlex Zinenko 12*684dfe8aSAlex Zinenko using namespace mlir; 13*684dfe8aSAlex Zinenko 14*684dfe8aSAlex Zinenko // For >1-D vector types, extracts the necessary information to iterate over all 15*684dfe8aSAlex Zinenko // 1-D subvectors in the underlying llrepresentation of the n-D vector 16*684dfe8aSAlex Zinenko // Iterates on the llvm array type until we hit a non-array type (which is 17*684dfe8aSAlex Zinenko // asserted to be an llvm vector type). 18*684dfe8aSAlex Zinenko LLVM::detail::NDVectorTypeInfo 19*684dfe8aSAlex Zinenko LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType, 20*684dfe8aSAlex Zinenko LLVMTypeConverter &converter) { 21*684dfe8aSAlex Zinenko assert(vectorType.getRank() > 1 && "expected >1D vector type"); 22*684dfe8aSAlex Zinenko NDVectorTypeInfo info; 23*684dfe8aSAlex Zinenko info.llvmNDVectorTy = converter.convertType(vectorType); 24*684dfe8aSAlex Zinenko if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) { 25*684dfe8aSAlex Zinenko info.llvmNDVectorTy = nullptr; 26*684dfe8aSAlex Zinenko return info; 27*684dfe8aSAlex Zinenko } 28*684dfe8aSAlex Zinenko info.arraySizes.reserve(vectorType.getRank() - 1); 29*684dfe8aSAlex Zinenko auto llvmTy = info.llvmNDVectorTy; 30*684dfe8aSAlex Zinenko while (llvmTy.isa<LLVM::LLVMArrayType>()) { 31*684dfe8aSAlex Zinenko info.arraySizes.push_back( 32*684dfe8aSAlex Zinenko llvmTy.cast<LLVM::LLVMArrayType>().getNumElements()); 33*684dfe8aSAlex Zinenko llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType(); 34*684dfe8aSAlex Zinenko } 35*684dfe8aSAlex Zinenko if (!LLVM::isCompatibleVectorType(llvmTy)) 36*684dfe8aSAlex Zinenko return info; 37*684dfe8aSAlex Zinenko info.llvm1DVectorTy = llvmTy; 38*684dfe8aSAlex Zinenko return info; 39*684dfe8aSAlex Zinenko } 40*684dfe8aSAlex Zinenko 41*684dfe8aSAlex Zinenko // Express `linearIndex` in terms of coordinates of `basis`. 42*684dfe8aSAlex Zinenko // Returns the empty vector when linearIndex is out of the range [0, P] where 43*684dfe8aSAlex Zinenko // P is the product of all the basis coordinates. 44*684dfe8aSAlex Zinenko // 45*684dfe8aSAlex Zinenko // Prerequisites: 46*684dfe8aSAlex Zinenko // Basis is an array of nonnegative integers (signed type inherited from 47*684dfe8aSAlex Zinenko // vector shape type). 48*684dfe8aSAlex Zinenko SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis, 49*684dfe8aSAlex Zinenko unsigned linearIndex) { 50*684dfe8aSAlex Zinenko SmallVector<int64_t, 4> res; 51*684dfe8aSAlex Zinenko res.reserve(basis.size()); 52*684dfe8aSAlex Zinenko for (unsigned basisElement : llvm::reverse(basis)) { 53*684dfe8aSAlex Zinenko res.push_back(linearIndex % basisElement); 54*684dfe8aSAlex Zinenko linearIndex = linearIndex / basisElement; 55*684dfe8aSAlex Zinenko } 56*684dfe8aSAlex Zinenko if (linearIndex > 0) 57*684dfe8aSAlex Zinenko return {}; 58*684dfe8aSAlex Zinenko std::reverse(res.begin(), res.end()); 59*684dfe8aSAlex Zinenko return res; 60*684dfe8aSAlex Zinenko } 61*684dfe8aSAlex Zinenko 62*684dfe8aSAlex Zinenko // Iterate of linear index, convert to coords space and insert splatted 1-D 63*684dfe8aSAlex Zinenko // vector in each position. 64*684dfe8aSAlex Zinenko void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info, 65*684dfe8aSAlex Zinenko OpBuilder &builder, 66*684dfe8aSAlex Zinenko function_ref<void(ArrayAttr)> fun) { 67*684dfe8aSAlex Zinenko unsigned ub = 1; 68*684dfe8aSAlex Zinenko for (auto s : info.arraySizes) 69*684dfe8aSAlex Zinenko ub *= s; 70*684dfe8aSAlex Zinenko for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { 71*684dfe8aSAlex Zinenko auto coords = getCoordinates(info.arraySizes, linearIndex); 72*684dfe8aSAlex Zinenko // Linear index is out of bounds, we are done. 73*684dfe8aSAlex Zinenko if (coords.empty()) 74*684dfe8aSAlex Zinenko break; 75*684dfe8aSAlex Zinenko assert(coords.size() == info.arraySizes.size()); 76*684dfe8aSAlex Zinenko auto position = builder.getI64ArrayAttr(coords); 77*684dfe8aSAlex Zinenko fun(position); 78*684dfe8aSAlex Zinenko } 79*684dfe8aSAlex Zinenko } 80*684dfe8aSAlex Zinenko 81*684dfe8aSAlex Zinenko LogicalResult LLVM::detail::handleMultidimensionalVectors( 82*684dfe8aSAlex Zinenko Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, 83*684dfe8aSAlex Zinenko std::function<Value(Type, ValueRange)> createOperand, 84*684dfe8aSAlex Zinenko ConversionPatternRewriter &rewriter) { 85*684dfe8aSAlex Zinenko auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>(); 86*684dfe8aSAlex Zinenko 87*684dfe8aSAlex Zinenko SmallVector<Type> operand1DVectorTypes; 88*684dfe8aSAlex Zinenko for (Value operand : op->getOperands()) { 89*684dfe8aSAlex Zinenko auto operandNDVectorType = operand.getType().cast<VectorType>(); 90*684dfe8aSAlex Zinenko auto operandTypeInfo = 91*684dfe8aSAlex Zinenko extractNDVectorTypeInfo(operandNDVectorType, typeConverter); 92*684dfe8aSAlex Zinenko operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy); 93*684dfe8aSAlex Zinenko } 94*684dfe8aSAlex Zinenko auto resultTypeInfo = 95*684dfe8aSAlex Zinenko extractNDVectorTypeInfo(resultNDVectorType, typeConverter); 96*684dfe8aSAlex Zinenko auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; 97*684dfe8aSAlex Zinenko auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; 98*684dfe8aSAlex Zinenko auto loc = op->getLoc(); 99*684dfe8aSAlex Zinenko Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy); 100*684dfe8aSAlex Zinenko nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) { 101*684dfe8aSAlex Zinenko // For this unrolled `position` corresponding to the `linearIndex`^th 102*684dfe8aSAlex Zinenko // element, extract operand vectors 103*684dfe8aSAlex Zinenko SmallVector<Value, 4> extractedOperands; 104*684dfe8aSAlex Zinenko for (auto operand : llvm::enumerate(operands)) { 105*684dfe8aSAlex Zinenko extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( 106*684dfe8aSAlex Zinenko loc, operand1DVectorTypes[operand.index()], operand.value(), 107*684dfe8aSAlex Zinenko position)); 108*684dfe8aSAlex Zinenko } 109*684dfe8aSAlex Zinenko Value newVal = createOperand(result1DVectorTy, extractedOperands); 110*684dfe8aSAlex Zinenko desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc, 111*684dfe8aSAlex Zinenko newVal, position); 112*684dfe8aSAlex Zinenko }); 113*684dfe8aSAlex Zinenko rewriter.replaceOp(op, desc); 114*684dfe8aSAlex Zinenko return success(); 115*684dfe8aSAlex Zinenko } 116*684dfe8aSAlex Zinenko 117*684dfe8aSAlex Zinenko LogicalResult LLVM::detail::vectorOneToOneRewrite( 118*684dfe8aSAlex Zinenko Operation *op, StringRef targetOp, ValueRange operands, 119*684dfe8aSAlex Zinenko LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { 120*684dfe8aSAlex Zinenko assert(!operands.empty()); 121*684dfe8aSAlex Zinenko 122*684dfe8aSAlex Zinenko // Cannot convert ops if their operands are not of LLVM type. 123*684dfe8aSAlex Zinenko if (!llvm::all_of(operands.getTypes(), 124*684dfe8aSAlex Zinenko [](Type t) { return isCompatibleType(t); })) 125*684dfe8aSAlex Zinenko return failure(); 126*684dfe8aSAlex Zinenko 127*684dfe8aSAlex Zinenko auto llvmNDVectorTy = operands[0].getType(); 128*684dfe8aSAlex Zinenko if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) 129*684dfe8aSAlex Zinenko return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); 130*684dfe8aSAlex Zinenko 131*684dfe8aSAlex Zinenko auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy, 132*684dfe8aSAlex Zinenko ValueRange operands) { 133*684dfe8aSAlex Zinenko OperationState state(op->getLoc(), targetOp); 134*684dfe8aSAlex Zinenko state.addTypes(llvm1DVectorTy); 135*684dfe8aSAlex Zinenko state.addOperands(operands); 136*684dfe8aSAlex Zinenko state.addAttributes(op->getAttrs()); 137*684dfe8aSAlex Zinenko return rewriter.createOperation(state)->getResult(0); 138*684dfe8aSAlex Zinenko }; 139*684dfe8aSAlex Zinenko 140*684dfe8aSAlex Zinenko return handleMultidimensionalVectors(op, operands, typeConverter, callback, 141*684dfe8aSAlex Zinenko rewriter); 142*684dfe8aSAlex Zinenko } 143