1684dfe8aSAlex Zinenko //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===// 2684dfe8aSAlex Zinenko // 3684dfe8aSAlex Zinenko // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4684dfe8aSAlex Zinenko // See https://llvm.org/LICENSE.txt for license information. 5684dfe8aSAlex Zinenko // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6684dfe8aSAlex Zinenko // 7684dfe8aSAlex Zinenko //===----------------------------------------------------------------------===// 8684dfe8aSAlex Zinenko 9684dfe8aSAlex Zinenko #include "mlir/Conversion/LLVMCommon/VectorPattern.h" 10684dfe8aSAlex Zinenko #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 11684dfe8aSAlex Zinenko 12684dfe8aSAlex Zinenko using namespace mlir; 13684dfe8aSAlex Zinenko 14684dfe8aSAlex Zinenko // For >1-D vector types, extracts the necessary information to iterate over all 15684dfe8aSAlex Zinenko // 1-D subvectors in the underlying llrepresentation of the n-D vector 16684dfe8aSAlex Zinenko // Iterates on the llvm array type until we hit a non-array type (which is 17684dfe8aSAlex Zinenko // asserted to be an llvm vector type). 18684dfe8aSAlex Zinenko LLVM::detail::NDVectorTypeInfo 19684dfe8aSAlex Zinenko LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType, 20684dfe8aSAlex Zinenko LLVMTypeConverter &converter) { 21684dfe8aSAlex Zinenko assert(vectorType.getRank() > 1 && "expected >1D vector type"); 22684dfe8aSAlex Zinenko NDVectorTypeInfo info; 23684dfe8aSAlex Zinenko info.llvmNDVectorTy = converter.convertType(vectorType); 24684dfe8aSAlex Zinenko if (!info.llvmNDVectorTy || !LLVM::isCompatibleType(info.llvmNDVectorTy)) { 25684dfe8aSAlex Zinenko info.llvmNDVectorTy = nullptr; 26684dfe8aSAlex Zinenko return info; 27684dfe8aSAlex Zinenko } 28684dfe8aSAlex Zinenko info.arraySizes.reserve(vectorType.getRank() - 1); 29684dfe8aSAlex Zinenko auto llvmTy = info.llvmNDVectorTy; 30684dfe8aSAlex Zinenko while (llvmTy.isa<LLVM::LLVMArrayType>()) { 31684dfe8aSAlex Zinenko info.arraySizes.push_back( 32684dfe8aSAlex Zinenko llvmTy.cast<LLVM::LLVMArrayType>().getNumElements()); 33684dfe8aSAlex Zinenko llvmTy = llvmTy.cast<LLVM::LLVMArrayType>().getElementType(); 34684dfe8aSAlex Zinenko } 35684dfe8aSAlex Zinenko if (!LLVM::isCompatibleVectorType(llvmTy)) 36684dfe8aSAlex Zinenko return info; 37684dfe8aSAlex Zinenko info.llvm1DVectorTy = llvmTy; 38684dfe8aSAlex Zinenko return info; 39684dfe8aSAlex Zinenko } 40684dfe8aSAlex Zinenko 41684dfe8aSAlex Zinenko // Express `linearIndex` in terms of coordinates of `basis`. 42684dfe8aSAlex Zinenko // Returns the empty vector when linearIndex is out of the range [0, P] where 43684dfe8aSAlex Zinenko // P is the product of all the basis coordinates. 44684dfe8aSAlex Zinenko // 45684dfe8aSAlex Zinenko // Prerequisites: 46684dfe8aSAlex Zinenko // Basis is an array of nonnegative integers (signed type inherited from 47684dfe8aSAlex Zinenko // vector shape type). 48684dfe8aSAlex Zinenko SmallVector<int64_t, 4> LLVM::detail::getCoordinates(ArrayRef<int64_t> basis, 49684dfe8aSAlex Zinenko unsigned linearIndex) { 50684dfe8aSAlex Zinenko SmallVector<int64_t, 4> res; 51684dfe8aSAlex Zinenko res.reserve(basis.size()); 52684dfe8aSAlex Zinenko for (unsigned basisElement : llvm::reverse(basis)) { 53684dfe8aSAlex Zinenko res.push_back(linearIndex % basisElement); 54684dfe8aSAlex Zinenko linearIndex = linearIndex / basisElement; 55684dfe8aSAlex Zinenko } 56684dfe8aSAlex Zinenko if (linearIndex > 0) 57684dfe8aSAlex Zinenko return {}; 58684dfe8aSAlex Zinenko std::reverse(res.begin(), res.end()); 59684dfe8aSAlex Zinenko return res; 60684dfe8aSAlex Zinenko } 61684dfe8aSAlex Zinenko 62684dfe8aSAlex Zinenko // Iterate of linear index, convert to coords space and insert splatted 1-D 63684dfe8aSAlex Zinenko // vector in each position. 64684dfe8aSAlex Zinenko void LLVM::detail::nDVectorIterate(const LLVM::detail::NDVectorTypeInfo &info, 65684dfe8aSAlex Zinenko OpBuilder &builder, 66684dfe8aSAlex Zinenko function_ref<void(ArrayAttr)> fun) { 67684dfe8aSAlex Zinenko unsigned ub = 1; 68684dfe8aSAlex Zinenko for (auto s : info.arraySizes) 69684dfe8aSAlex Zinenko ub *= s; 70684dfe8aSAlex Zinenko for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) { 71684dfe8aSAlex Zinenko auto coords = getCoordinates(info.arraySizes, linearIndex); 72684dfe8aSAlex Zinenko // Linear index is out of bounds, we are done. 73684dfe8aSAlex Zinenko if (coords.empty()) 74684dfe8aSAlex Zinenko break; 75684dfe8aSAlex Zinenko assert(coords.size() == info.arraySizes.size()); 76684dfe8aSAlex Zinenko auto position = builder.getI64ArrayAttr(coords); 77684dfe8aSAlex Zinenko fun(position); 78684dfe8aSAlex Zinenko } 79684dfe8aSAlex Zinenko } 80684dfe8aSAlex Zinenko 81684dfe8aSAlex Zinenko LogicalResult LLVM::detail::handleMultidimensionalVectors( 82684dfe8aSAlex Zinenko Operation *op, ValueRange operands, LLVMTypeConverter &typeConverter, 83684dfe8aSAlex Zinenko std::function<Value(Type, ValueRange)> createOperand, 84684dfe8aSAlex Zinenko ConversionPatternRewriter &rewriter) { 85684dfe8aSAlex Zinenko auto resultNDVectorType = op->getResult(0).getType().cast<VectorType>(); 86684dfe8aSAlex Zinenko 87684dfe8aSAlex Zinenko SmallVector<Type> operand1DVectorTypes; 88684dfe8aSAlex Zinenko for (Value operand : op->getOperands()) { 89684dfe8aSAlex Zinenko auto operandNDVectorType = operand.getType().cast<VectorType>(); 90684dfe8aSAlex Zinenko auto operandTypeInfo = 91684dfe8aSAlex Zinenko extractNDVectorTypeInfo(operandNDVectorType, typeConverter); 92684dfe8aSAlex Zinenko operand1DVectorTypes.push_back(operandTypeInfo.llvm1DVectorTy); 93684dfe8aSAlex Zinenko } 94684dfe8aSAlex Zinenko auto resultTypeInfo = 95684dfe8aSAlex Zinenko extractNDVectorTypeInfo(resultNDVectorType, typeConverter); 96684dfe8aSAlex Zinenko auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy; 97684dfe8aSAlex Zinenko auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy; 98684dfe8aSAlex Zinenko auto loc = op->getLoc(); 99684dfe8aSAlex Zinenko Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy); 100684dfe8aSAlex Zinenko nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayAttr position) { 101684dfe8aSAlex Zinenko // For this unrolled `position` corresponding to the `linearIndex`^th 102684dfe8aSAlex Zinenko // element, extract operand vectors 103684dfe8aSAlex Zinenko SmallVector<Value, 4> extractedOperands; 104*e4853be2SMehdi Amini for (const auto &operand : llvm::enumerate(operands)) { 105684dfe8aSAlex Zinenko extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>( 106684dfe8aSAlex Zinenko loc, operand1DVectorTypes[operand.index()], operand.value(), 107684dfe8aSAlex Zinenko position)); 108684dfe8aSAlex Zinenko } 109684dfe8aSAlex Zinenko Value newVal = createOperand(result1DVectorTy, extractedOperands); 110684dfe8aSAlex Zinenko desc = rewriter.create<LLVM::InsertValueOp>(loc, resultNDVectoryTy, desc, 111684dfe8aSAlex Zinenko newVal, position); 112684dfe8aSAlex Zinenko }); 113684dfe8aSAlex Zinenko rewriter.replaceOp(op, desc); 114684dfe8aSAlex Zinenko return success(); 115684dfe8aSAlex Zinenko } 116684dfe8aSAlex Zinenko 117684dfe8aSAlex Zinenko LogicalResult LLVM::detail::vectorOneToOneRewrite( 118684dfe8aSAlex Zinenko Operation *op, StringRef targetOp, ValueRange operands, 119684dfe8aSAlex Zinenko LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter) { 120684dfe8aSAlex Zinenko assert(!operands.empty()); 121684dfe8aSAlex Zinenko 122684dfe8aSAlex Zinenko // Cannot convert ops if their operands are not of LLVM type. 123684dfe8aSAlex Zinenko if (!llvm::all_of(operands.getTypes(), 124684dfe8aSAlex Zinenko [](Type t) { return isCompatibleType(t); })) 125684dfe8aSAlex Zinenko return failure(); 126684dfe8aSAlex Zinenko 127684dfe8aSAlex Zinenko auto llvmNDVectorTy = operands[0].getType(); 128684dfe8aSAlex Zinenko if (!llvmNDVectorTy.isa<LLVM::LLVMArrayType>()) 129684dfe8aSAlex Zinenko return oneToOneRewrite(op, targetOp, operands, typeConverter, rewriter); 130684dfe8aSAlex Zinenko 131684dfe8aSAlex Zinenko auto callback = [op, targetOp, &rewriter](Type llvm1DVectorTy, 132684dfe8aSAlex Zinenko ValueRange operands) { 133684dfe8aSAlex Zinenko OperationState state(op->getLoc(), targetOp); 134684dfe8aSAlex Zinenko state.addTypes(llvm1DVectorTy); 135684dfe8aSAlex Zinenko state.addOperands(operands); 136684dfe8aSAlex Zinenko state.addAttributes(op->getAttrs()); 137684dfe8aSAlex Zinenko return rewriter.createOperation(state)->getResult(0); 138684dfe8aSAlex Zinenko }; 139684dfe8aSAlex Zinenko 140684dfe8aSAlex Zinenko return handleMultidimensionalVectors(op, operands, typeConverter, callback, 141684dfe8aSAlex Zinenko rewriter); 142684dfe8aSAlex Zinenko } 143