Home
last modified time | relevance | path

Searched refs:vectorType (Results 1 – 25 of 44) sorted by relevance

12

/llvm-project-15.0.7/mlir/lib/Dialect/Vector/Transforms/
H A DVectorTransferOpTransforms.cpp266 VectorType vectorType = vector.getType().cast<VectorType>(); in matchAndRewrite() local
272 if (sourceType.getNumElements() != vectorType.getNumElements()) in matchAndRewrite()
282 if (reducedRank != vectorType.getRank()) in matchAndRewrite()
309 VectorType vectorType = vector.getType().cast<VectorType>(); in matchAndRewrite() local
315 if (sourceType.getNumElements() != vectorType.getNumElements()) in matchAndRewrite()
325 if (reducedRank != vectorType.getRank()) in matchAndRewrite()
421 VectorType vectorType = vector.getType().cast<VectorType>(); in matchAndRewrite() local
427 if (vectorType.getRank() <= 1) in matchAndRewrite()
431 getContiguousInnerDim(sourceType, vectorType.getNumElements()); in matchAndRewrite()
479 VectorType vectorType = vector.getType().cast<VectorType>(); in matchAndRewrite() local
[all …]
H A DVectorInsertExtractStridedSliceRewritePatterns.cpp23 auto vectorType = into.getType().cast<VectorType>(); in insertOne() local
24 if (vectorType.getRank() > 1) in insertOne()
27 loc, vectorType, from, into, in insertOne()
34 auto vectorType = vector.getType().cast<VectorType>(); in extractOne() local
35 if (vectorType.getRank() > 1) in extractOne()
38 loc, vectorType.getElementType(), vector, in extractOne()
H A DVectorDistribute.cpp56 auto vectorType = val.getType().cast<VectorType>(); in rewriteWarpOpToScfFor() local
57 int64_t storeSize = vectorType.getShape()[0]; in rewriteWarpOpToScfFor()
877 auto vectorType = reductionOp.getVector().getType().cast<VectorType>(); in matchAndRewrite() local
879 if (vectorType.getRank() != 1) in matchAndRewrite()
883 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0) in matchAndRewrite()
893 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize(); in matchAndRewrite()
/llvm-project-15.0.7/mlir/lib/Conversion/VectorToGPU/
H A DNvGpuSupport.cpp36 auto shape = type.vectorType.getShape(); in inferNumRegistersPerMatrixFragment()
59 info.vectorType = writeOp.getVectorType(); in getWarpMatrixInfo()
62 info.vectorType = op->getResult(0).getType().cast<VectorType>(); in getWarpMatrixInfo()
89 Type elType = type.vectorType.getElementType(); in inferTileWidthInBits()
101 MLIRContext *ctx = type.vectorType.getContext(); in getMmaSyncRegisterType()
104 Type elType = type.vectorType.getElementType(); in getMmaSyncRegisterType()
175 Type elementType = fragmentType.vectorType.getElementType(); in getLaneIdAndValueIdToOperandCoord()
176 ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape(); in getLaneIdAndValueIdToOperandCoord()
211 Type elType = type.vectorType.getElementType(); in getLdMatrixParams()
212 params.fragmentType = type.vectorType; in getLdMatrixParams()
[all …]
H A DVectorToGPU.cpp457 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); in convertConstantOpMmaSync() local
462 op.getLoc(), vectorType, in convertConstantOpMmaSync()
500 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); in creatLdMatrixCompatibleLoads() local
506 loc, vectorType, op.getSource(), indices, in creatLdMatrixCompatibleLoads()
533 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); in createNonLdMatrixLoads() local
536 op.getLoc(), vectorType.getElementType(), in createNonLdMatrixLoads()
537 builder.getZeroAttr(vectorType.getElementType())); in createNonLdMatrixLoads()
549 for (int i = 0; i < vectorType.getShape()[0]; i++) { in createNonLdMatrixLoads()
570 for (int i = 0; i < vectorType.getShape()[0]; i++) { in createNonLdMatrixLoads()
647 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); in convertTransferWriteToStores() local
[all …]
H A DNvGpuSupport.h32 VectorType vectorType; member
/llvm-project-15.0.7/mlir/lib/Conversion/MathToSPIRV/
H A DMathToSPIRV.cpp35 if (auto vectorType = type.dyn_cast<VectorType>()) { in getScalarOrVectorI32Constant() local
36 if (!vectorType.getElementType().isInteger(32)) in getScalarOrVectorI32Constant()
38 SmallVector<int> values(vectorType.getNumElements(), value); in getScalarOrVectorI32Constant()
72 } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { in matchAndRewrite() local
73 floatType = vectorType.getElementType().cast<FloatType>(); in matchAndRewrite()
88 if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { in matchAndRewrite() local
89 assert(vectorType.getRank() == 1); in matchAndRewrite()
90 int count = vectorType.getNumElements(); in matchAndRewrite()
139 if (auto vectorType = type.dyn_cast<VectorType>()) in matchAndRewrite() local
140 bitwidth = vectorType.getElementTypeBitWidth(); in matchAndRewrite()
/llvm-project-15.0.7/mlir/lib/Conversion/VectorToLLVM/
H A DConvertVectorToLLVM.cpp559 auto vectorType = shuffleOp.getVectorType(); in matchAndRewrite() local
568 int64_t rank = vectorType.getRank(); in matchAndRewrite()
616 auto vectorType = extractEltOp.getVectorType(); in matchAndRewrite() local
623 if (vectorType.getRank() == 0) { in matchAndRewrite()
649 auto vectorType = extractOp.getVectorType(); in matchAndRewrite() local
742 if (vectorType.getRank() == 0) { in matchAndRewrite()
1053 Type eltType = vectorType ? vectorType.getElementType() : printType; in matchAndRewrite()
1098 int64_t rank = vectorType ? vectorType.getRank() : 0; in matchAndRewrite()
1099 Type type = vectorType ? vectorType : eltType; in matchAndRewrite()
1123 if (!vectorType) { in emitRanks()
[all …]
/llvm-project-15.0.7/mlir/lib/Dialect/SPIRV/Utils/
H A DLayoutUtils.cpp92 if (auto vectorType = type.dyn_cast<VectorType>()) in decorateType() local
93 return decorateType(vectorType, size, alignment); in decorateType()
101 Type VulkanLayoutUtils::decorateType(VectorType vectorType, in decorateType() argument
104 const auto numElements = vectorType.getNumElements(); in decorateType()
105 auto elementType = vectorType.getElementType(); in decorateType()
/llvm-project-15.0.7/mlir/lib/Conversion/MathToLLVM/
H A DMathToLLVM.cpp70 auto vectorType = resultType.template dyn_cast<VectorType>(); in matchAndRewrite() local
71 if (!vectorType) in matchAndRewrite()
122 auto vectorType = resultType.dyn_cast<VectorType>(); in matchAndRewrite() local
123 if (!vectorType) in matchAndRewrite()
176 auto vectorType = resultType.dyn_cast<VectorType>(); in matchAndRewrite() local
177 if (!vectorType) in matchAndRewrite()
229 auto vectorType = resultType.dyn_cast<VectorType>(); in matchAndRewrite() local
230 if (!vectorType) in matchAndRewrite()
/llvm-project-15.0.7/mlir/unittests/IR/
H A DShapedTypeTest.cpp117 ShapedType vectorType = in TEST() local
122 ASSERT_EQ(vectorType.clone(vectorNewShape), in TEST()
127 ASSERT_EQ(vectorType.clone(vectorNewType), in TEST()
130 ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType), in TEST()
/llvm-project-15.0.7/mlir/lib/Conversion/LLVMCommon/
H A DVectorPattern.cpp19 LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType, in extractNDVectorTypeInfo() argument
21 assert(vectorType.getRank() > 1 && "expected >1D vector type"); in extractNDVectorTypeInfo()
23 info.llvmNDVectorTy = converter.convertType(vectorType); in extractNDVectorTypeInfo()
28 info.arraySizes.reserve(vectorType.getRank() - 1); in extractNDVectorTypeInfo()
H A DTypeConverter.cpp416 Type vectorType = VectorType::get(type.getShape().back(), elementType, in convertVectorType() local
418 assert(LLVM::isCompatibleVectorType(vectorType) && in convertVectorType()
422 vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]); in convertVectorType()
423 return vectorType; in convertVectorType()
/llvm-project-15.0.7/mlir/lib/Dialect/Vector/IR/
H A DVectorOps.cpp745 unsigned rank = vectorType ? vectorType.getShape().size() : 0; in verify()
793 auto elementType = vectorType ? vectorType.getElementType() : resType; in verify()
971 if (vectorType.getRank() != 1) in verify()
1033 vectorType.getShape().drop_front(n), vectorType.getElementType())); in inferReturnTypes()
1042 return vectorType && vectorType.getShape().equals({1}) && in isCompatibleReturnTypes()
2894 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back(); in verifyTransferOp()
2992 if (!vectorType) in parse()
3445 if (!vectorType) in parse()
4506 if (vectorType) in extractShape()
4507 res.append(vectorType.getShape().begin(), vectorType.getShape().end()); in extractShape()
[all …]
/llvm-project-15.0.7/mlir/lib/Conversion/ArithmeticToLLVM/
H A DArithmeticToLLVM.cpp178 auto vectorType = resultType.dyn_cast<VectorType>(); in matchAndRewrite() local
179 if (!vectorType) in matchAndRewrite()
213 auto vectorType = resultType.dyn_cast<VectorType>(); in matchAndRewrite() local
214 if (!vectorType) in matchAndRewrite()
/llvm-project-15.0.7/mlir/lib/Dialect/Quant/Utils/
H A DUniformSupport.cpp41 if (auto vectorType = inputType.dyn_cast<VectorType>()) in convert() local
42 return VectorType::get(vectorType.getShape(), elementalType); in convert()
/llvm-project-15.0.7/mlir/lib/Dialect/LLVMIR/IR/
H A DLLVMTypeSyntax.cpp160 if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>()) in printType() local
161 return printArrayOrVectorType(printer, vectorType); in printType()
163 if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) { in printType() local
164 printer << "<? x " << vectorType.getMinNumElements() << " x "; in printType()
165 dispatchPrint(printer, vectorType.getElementType()); in printType()
H A DLLVMDialect.cpp561 if (auto vectorType = type.dyn_cast<VectorType>()) in extractVectorElementType() local
562 return vectorType.getElementType(); in extractVectorElementType()
1353 auto vectorType = vector.getType(); in build() local
1354 auto llvmType = LLVM::getVectorElementType(vectorType); in build()
1389 Type vectorType = getVector().getType(); in verify() local
1390 if (!LLVM::isCompatibleVectorType(vectorType)) in verify()
1393 << vectorType; in verify()
1592 Type vectorType, positionType; in parse() local
1598 parser.parseColonType(vectorType)) in parse()
1601 if (!LLVM::isCompatibleVectorType(vectorType)) in parse()
[all …]
H A DLLVMTypes.cpp915 bool mlir::LLVM::isScalableVectorType(Type vectorType) { in isScalableVectorType() argument
917 (vectorType in isScalableVectorType()
920 return !vectorType.isa<LLVMFixedVectorType>() && in isScalableVectorType()
921 (vectorType.isa<LLVMScalableVectorType>() || in isScalableVectorType()
922 vectorType.cast<VectorType>().isScalable()); in isScalableVectorType()
/llvm-project-15.0.7/mlir/lib/Dialect/SPIRV/IR/
H A DSPIRVTypes.cpp90 if (auto vectorType = type.dyn_cast<VectorType>()) in classof() local
91 return isValid(vectorType); in classof()
129 if (auto vectorType = dyn_cast<VectorType>()) in getNumElements() local
130 return vectorType.getNumElements(); in getNumElements()
185 if (auto vectorType = dyn_cast<VectorType>()) { in getSizeInBytes() local
187 vectorType.getElementType().cast<ScalarType>().getSizeInBytes(); in getSizeInBytes()
190 return *elementSize * vectorType.getNumElements(); in getSizeInBytes()
657 if (auto vectorType = type.dyn_cast<VectorType>()) in classof() local
658 return CompositeType::isValid(vectorType); in classof()
1130 if (auto vectorType = columnType.dyn_cast<VectorType>()) { in isValidColumnType() local
[all …]
/llvm-project-15.0.7/mlir/lib/Dialect/Quant/IR/
H A DQuantOps.cpp62 if (auto vectorType = expressed.dyn_cast<VectorType>()) in isValidQuantizationSpec() local
63 return spec == vectorType.getElementType(); in isValidQuantizationSpec()
/llvm-project-15.0.7/mlir/lib/Dialect/SPIRV/Transforms/
H A DUnifyAliasedResourcePass.cpp87 if (auto vectorType = type.dyn_cast<VectorType>()) { in deduceCanonicalResource() local
88 if (vectorType.getNumElements() % 2 != 0) in deduceCanonicalResource()
96 vectorType.getElementType().getIntOrFloatBitWidth()); in deduceCanonicalResource()
466 auto vectorType = VectorType::get({ratio}, dstElemType); in matchAndRewrite() local
468 loc, vectorType, components); in matchAndRewrite()
/llvm-project-15.0.7/mlir/include/mlir/Dialect/SPIRV/Utils/
H A DLayoutUtils.h67 static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
/llvm-project-15.0.7/mlir/include/mlir/Conversion/LLVMCommon/
H A DVectorPattern.h34 NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
/llvm-project-15.0.7/mlir/include/mlir/Dialect/Vector/IR/
H A DVectorOps.h165 VectorType vectorType);

12