1 //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 11 12 using namespace mlir; 13 14 // For >1-D vector types, extracts the necessary information to iterate over all 15 // 1-D subvectors in the underlying llrepresentation of the n-D vector 16 // Iterates on the llvm array type until we hit a non-array type (which is 17 // asserted to be an llvm vector type). 18 LLVM::detail::NDVectorTypeInfo 19 LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType, 20 LLVMTypeConverter &converter) { 21 assert(vectorType.getRank() > 1 && "expected >1D vector type"); 22 NDVectorTypeInfo info; 23 info.llvmNDVectorTy = converter.convertType(vectorType); 24 if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) { 25 info.llvmNDVectorTy = nullptr; 26 return info; 27 } 28 info.arraySizes.reserve(vectorType.getRank() - 1); 29 auto llvmTy = info.llvmNDVectorTy; 30 while (llvmTy.isa<LLVM::LLVMArrayType>()) { 31 info.arraySizes.push_back( 32 llvmTy.cast<LLVM::LLVMArrayType>().getNumElements()); 33 llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType(); 34 } 35 if (!LLVM::isCompatibleVectorType(llvmTy)) 36 return info; 37 info.llvm1DVectorTy = llvmTy; 38 return info; 39 } 40 41 // Express `linearIndex` in terms of coordinates of `basis`. 42 // Returns the empty vector when linearIndex is out of the range [0, P] where 43 // P is the product of all the basis coordinates. 44 // 45 // Prerequisites: 46 // Basis is an array of nonnegative integers (signed type inherited from 47 // vector shape type). 48 SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis, 49 unsigned linearIndex) { 50 SmallVector<int64_t, 4> res; 51 res.reserve(basis.size()); 52 for (unsigned basisElement : llvm::reverse(basis)) { 53 res.push_back(linearIndex % basisElement); 54 linearIndex = linearIndex / basisElement; 55 } 56 if (linearIndex > 0) 57 return {}; 58 std::reverse(res.begin(), res.end()); 59 return res; 60 } 61 62 // Iterate of linear index, convert to coords space and insert splatted 1-D 63 // vector in each position. 64 void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info, 65 OpBuilder &builder, 66 function_ref<void(ArrayAttr)> fun) { 67 unsigned ub = 1; 68 for (auto s : info.arraySizes) 69 ub *= s; 70 for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { 71 auto coords = getCoordinates(info.arraySizes, linearIndex); 72 // Linear index is out of bounds, we are done. 73 if (coords.empty()) 74 break; 75 assert(coords.size() == info.arraySizes.size()); 76 auto position = builder.getI64ArrayAttr(coords); 77 fun(position); 78 } 79 } 80 81 LogicalResult LLVM::detail::handleMultidimensionalVectors( 82 Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, 83 std::function<Value(Type, ValueRange)> createOperand, 84 ConversionPatternRewriter &rewriter) { 85 auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>(); 86 87 SmallVector<Type> operand1DVectorTypes; 88 for (Value operand : op->getOperands()) { 89 auto operandNDVectorType = operand.getType().cast<VectorType>(); 90 auto operandTypeInfo = 91 extractNDVectorTypeInfo(operandNDVectorType, typeConverter); 92 operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy); 93 } 94 auto resultTypeInfo = 95 extractNDVectorTypeInfo(resultNDVectorType, typeConverter); 96 auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; 97 auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; 98 auto loc = op->getLoc(); 99 Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy); 100 nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) { 101 // For this unrolled `position` corresponding to the `linearIndex`^th 102 // element, extract operand vectors 103 SmallVector<Value, 4> extractedOperands; 104 for (const auto &operand : llvm::enumerate(operands)) { 105 extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( 106 loc, operand1DVectorTypes[operand.index()], operand.value(), 107 position)); 108 } 109 Value newVal = createOperand(result1DVectorTy, extractedOperands); 110 desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc, 111 newVal, position); 112 }); 113 rewriter.replaceOp(op, desc); 114 return success(); 115 } 116 117 LogicalResult LLVM::detail::vectorOneToOneRewrite( 118 Operation *op, StringRef targetOp, ValueRange operands, 119 LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { 120 assert(!operands.empty()); 121 122 // Cannot convert ops if their operands are not of LLVM type. 123 if (!llvm::all_of(operands.getTypes(), 124 [](Type t) { return isCompatibleType(t); })) 125 return failure(); 126 127 auto llvmNDVectorTy = operands[0].getType(); 128 if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) 129 return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); 130 131 auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy, 132 ValueRange operands) { 133 OperationState state(op->getLoc(), targetOp); 134 state.addTypes(llvm1DVectorTy); 135 state.addOperands(operands); 136 state.addAttributes(op->getAttrs()); 137 return rewriter.createOperation(state)->getResult(0); 138 }; 139 140 return handleMultidimensionalVectors(op, operands, typeConverter, callback, 141 rewriter); 142 } 143