| /llvm-project-15.0.7/mlir/lib/Dialect/Vector/Transforms/ |
| H A D | VectorTransferOpTransforms.cpp | 266 VectorType vectorType = vector.getType().cast<VectorType>(); in matchAndRewrite() local 272 if (sourceType.getNumElements() != vectorType.getNumElements()) in matchAndRewrite() 282 if (reducedRank != vectorType.getRank()) in matchAndRewrite() 309 VectorType vectorType = vector.getType().cast<VectorType>(); in matchAndRewrite() local 315 if (sourceType.getNumElements() != vectorType.getNumElements()) in matchAndRewrite() 325 if (reducedRank != vectorType.getRank()) in matchAndRewrite() 421 VectorType vectorType = vector.getType().cast<VectorType>(); in matchAndRewrite() local 427 if (vectorType.getRank() <= 1) in matchAndRewrite() 431 getContiguousInnerDim(sourceType, vectorType.getNumElements()); in matchAndRewrite() 479 VectorType vectorType = vector.getType().cast<VectorType>(); in matchAndRewrite() local [all …]
|
| H A D | VectorInsertExtractStridedSliceRewritePatterns.cpp | 23 auto vectorType = into.getType().cast<VectorType>(); in insertOne() local 24 if (vectorType.getRank() > 1) in insertOne() 27 loc, vectorType, from, into, in insertOne() 34 auto vectorType = vector.getType().cast<VectorType>(); in extractOne() local 35 if (vectorType.getRank() > 1) in extractOne() 38 loc, vectorType.getElementType(), vector, in extractOne()
|
| H A D | VectorDistribute.cpp | 56 auto vectorType = val.getType().cast<VectorType>(); in rewriteWarpOpToScfFor() local 57 int64_t storeSize = vectorType.getShape()[0]; in rewriteWarpOpToScfFor() 877 auto vectorType = reductionOp.getVector().getType().cast<VectorType>(); in matchAndRewrite() local 879 if (vectorType.getRank() != 1) in matchAndRewrite() 883 if (vectorType.getShape()[0] % warpOp.getWarpSize() != 0) in matchAndRewrite() 893 int64_t numElements = vectorType.getShape()[0] / warpOp.getWarpSize(); in matchAndRewrite()
|
| /llvm-project-15.0.7/mlir/lib/Conversion/VectorToGPU/ |
| H A D | NvGpuSupport.cpp | 36 auto shape = type.vectorType.getShape(); in inferNumRegistersPerMatrixFragment() 59 info.vectorType = writeOp.getVectorType(); in getWarpMatrixInfo() 62 info.vectorType = op->getResult(0).getType().cast<VectorType>(); in getWarpMatrixInfo() 89 Type elType = type.vectorType.getElementType(); in inferTileWidthInBits() 101 MLIRContext *ctx = type.vectorType.getContext(); in getMmaSyncRegisterType() 104 Type elType = type.vectorType.getElementType(); in getMmaSyncRegisterType() 175 Type elementType = fragmentType.vectorType.getElementType(); in getLaneIdAndValueIdToOperandCoord() 176 ArrayRef<int64_t> operandShape = fragmentType.vectorType.getShape(); in getLaneIdAndValueIdToOperandCoord() 211 Type elType = type.vectorType.getElementType(); in getLdMatrixParams() 212 params.fragmentType = type.vectorType; in getLdMatrixParams() [all …]
|
| H A D | VectorToGPU.cpp | 457 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); in convertConstantOpMmaSync() local 462 op.getLoc(), vectorType, in convertConstantOpMmaSync() 500 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); in creatLdMatrixCompatibleLoads() local 506 loc, vectorType, op.getSource(), indices, in creatLdMatrixCompatibleLoads() 533 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); in createNonLdMatrixLoads() local 536 op.getLoc(), vectorType.getElementType(), in createNonLdMatrixLoads() 537 builder.getZeroAttr(vectorType.getElementType())); in createNonLdMatrixLoads() 549 for (int i = 0; i < vectorType.getShape()[0]; i++) { in createNonLdMatrixLoads() 570 for (int i = 0; i < vectorType.getShape()[0]; i++) { in createNonLdMatrixLoads() 647 VectorType vectorType = getMmaSyncVectorOperandType(*regInfo); in convertTransferWriteToStores() local [all …]
|
| H A D | NvGpuSupport.h | 32 VectorType vectorType; member
|
| /llvm-project-15.0.7/mlir/lib/Conversion/MathToSPIRV/ |
| H A D | MathToSPIRV.cpp | 35 if (auto vectorType = type.dyn_cast<VectorType>()) { in getScalarOrVectorI32Constant() local 36 if (!vectorType.getElementType().isInteger(32)) in getScalarOrVectorI32Constant() 38 SmallVector<int> values(vectorType.getNumElements(), value); in getScalarOrVectorI32Constant() 72 } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { in matchAndRewrite() local 73 floatType = vectorType.getElementType().cast<FloatType>(); in matchAndRewrite() 88 if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { in matchAndRewrite() local 89 assert(vectorType.getRank() == 1); in matchAndRewrite() 90 int count = vectorType.getNumElements(); in matchAndRewrite() 139 if (auto vectorType = type.dyn_cast<VectorType>()) in matchAndRewrite() local 140 bitwidth = vectorType.getElementTypeBitWidth(); in matchAndRewrite()
|
| /llvm-project-15.0.7/mlir/lib/Conversion/VectorToLLVM/ |
| H A D | ConvertVectorToLLVM.cpp | 559 auto vectorType = shuffleOp.getVectorType(); in matchAndRewrite() local 568 int64_t rank = vectorType.getRank(); in matchAndRewrite() 616 auto vectorType = extractEltOp.getVectorType(); in matchAndRewrite() local 623 if (vectorType.getRank() == 0) { in matchAndRewrite() 649 auto vectorType = extractOp.getVectorType(); in matchAndRewrite() local 742 if (vectorType.getRank() == 0) { in matchAndRewrite() 1053 Type eltType = vectorType ? vectorType.getElementType() : printType; in matchAndRewrite() 1098 int64_t rank = vectorType ? vectorType.getRank() : 0; in matchAndRewrite() 1099 Type type = vectorType ? vectorType : eltType; in matchAndRewrite() 1123 if (!vectorType) { in emitRanks() [all …]
|
| /llvm-project-15.0.7/mlir/lib/Dialect/SPIRV/Utils/ |
| H A D | LayoutUtils.cpp | 92 if (auto vectorType = type.dyn_cast<VectorType>()) in decorateType() local 93 return decorateType(vectorType, size, alignment); in decorateType() 101 Type VulkanLayoutUtils::decorateType(VectorType vectorType, in decorateType() argument 104 const auto numElements = vectorType.getNumElements(); in decorateType() 105 auto elementType = vectorType.getElementType(); in decorateType()
|
| /llvm-project-15.0.7/mlir/lib/Conversion/MathToLLVM/ |
| H A D | MathToLLVM.cpp | 70 auto vectorType = resultType.template dyn_cast<VectorType>(); in matchAndRewrite() local 71 if (!vectorType) in matchAndRewrite() 122 auto vectorType = resultType.dyn_cast<VectorType>(); in matchAndRewrite() local 123 if (!vectorType) in matchAndRewrite() 176 auto vectorType = resultType.dyn_cast<VectorType>(); in matchAndRewrite() local 177 if (!vectorType) in matchAndRewrite() 229 auto vectorType = resultType.dyn_cast<VectorType>(); in matchAndRewrite() local 230 if (!vectorType) in matchAndRewrite()
|
| /llvm-project-15.0.7/mlir/unittests/IR/ |
| H A D | ShapedTypeTest.cpp | 117 ShapedType vectorType = in TEST() local 122 ASSERT_EQ(vectorType.clone(vectorNewShape), in TEST() 127 ASSERT_EQ(vectorType.clone(vectorNewType), in TEST() 130 ASSERT_EQ(vectorType.clone(vectorNewShape, vectorNewType), in TEST()
|
| /llvm-project-15.0.7/mlir/lib/Conversion/LLVMCommon/ |
| H A D | VectorPattern.cpp | 19 LLVM::detail::extractNDVectorTypeInfo(VectorType vectorType, in extractNDVectorTypeInfo() argument 21 assert(vectorType.getRank() > 1 && "expected >1D vector type"); in extractNDVectorTypeInfo() 23 info.llvmNDVectorTy = converter.convertType(vectorType); in extractNDVectorTypeInfo() 28 info.arraySizes.reserve(vectorType.getRank() - 1); in extractNDVectorTypeInfo()
|
| H A D | TypeConverter.cpp | 416 Type vectorType = VectorType::get(type.getShape().back(), elementType, in convertVectorType() local 418 assert(LLVM::isCompatibleVectorType(vectorType) && in convertVectorType() 422 vectorType = LLVM::LLVMArrayType::get(vectorType, shape[i]); in convertVectorType() 423 return vectorType; in convertVectorType()
|
| /llvm-project-15.0.7/mlir/lib/Dialect/Vector/IR/ |
| H A D | VectorOps.cpp | 745 unsigned rank = vectorType ? vectorType.getShape().size() : 0; in verify() 793 auto elementType = vectorType ? vectorType.getElementType() : resType; in verify() 971 if (vectorType.getRank() != 1) in verify() 1033 vectorType.getShape().drop_front(n), vectorType.getElementType())); in inferReturnTypes() 1042 return vectorType && vectorType.getShape().equals({1}) && in isCompatibleReturnTypes() 2894 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back(); in verifyTransferOp() 2992 if (!vectorType) in parse() 3445 if (!vectorType) in parse() 4506 if (vectorType) in extractShape() 4507 res.append(vectorType.getShape().begin(), vectorType.getShape().end()); in extractShape() [all …]
|
| /llvm-project-15.0.7/mlir/lib/Conversion/ArithmeticToLLVM/ |
| H A D | ArithmeticToLLVM.cpp | 178 auto vectorType = resultType.dyn_cast<VectorType>(); in matchAndRewrite() local 179 if (!vectorType) in matchAndRewrite() 213 auto vectorType = resultType.dyn_cast<VectorType>(); in matchAndRewrite() local 214 if (!vectorType) in matchAndRewrite()
|
| /llvm-project-15.0.7/mlir/lib/Dialect/Quant/Utils/ |
| H A D | UniformSupport.cpp | 41 if (auto vectorType = inputType.dyn_cast<VectorType>()) in convert() local 42 return VectorType::get(vectorType.getShape(), elementalType); in convert()
|
| /llvm-project-15.0.7/mlir/lib/Dialect/LLVMIR/IR/ |
| H A D | LLVMTypeSyntax.cpp | 160 if (auto vectorType = type.dyn_cast<LLVMFixedVectorType>()) in printType() local 161 return printArrayOrVectorType(printer, vectorType); in printType() 163 if (auto vectorType = type.dyn_cast<LLVMScalableVectorType>()) { in printType() local 164 printer << "<? x " << vectorType.getMinNumElements() << " x "; in printType() 165 dispatchPrint(printer, vectorType.getElementType()); in printType()
|
| H A D | LLVMDialect.cpp | 561 if (auto vectorType = type.dyn_cast<VectorType>()) in extractVectorElementType() local 562 return vectorType.getElementType(); in extractVectorElementType() 1353 auto vectorType = vector.getType(); in build() local 1354 auto llvmType = LLVM::getVectorElementType(vectorType); in build() 1389 Type vectorType = getVector().getType(); in verify() local 1390 if (!LLVM::isCompatibleVectorType(vectorType)) in verify() 1393 << vectorType; in verify() 1592 Type vectorType, positionType; in parse() local 1598 parser.parseColonType(vectorType)) in parse() 1601 if (!LLVM::isCompatibleVectorType(vectorType)) in parse() [all …]
|
| H A D | LLVMTypes.cpp | 915 bool mlir::LLVM::isScalableVectorType(Type vectorType) { in isScalableVectorType() argument 917 (vectorType in isScalableVectorType() 920 return !vectorType.isa<LLVMFixedVectorType>() && in isScalableVectorType() 921 (vectorType.isa<LLVMScalableVectorType>() || in isScalableVectorType() 922 vectorType.cast<VectorType>().isScalable()); in isScalableVectorType()
|
| /llvm-project-15.0.7/mlir/lib/Dialect/SPIRV/IR/ |
| H A D | SPIRVTypes.cpp | 90 if (auto vectorType = type.dyn_cast<VectorType>()) in classof() local 91 return isValid(vectorType); in classof() 129 if (auto vectorType = dyn_cast<VectorType>()) in getNumElements() local 130 return vectorType.getNumElements(); in getNumElements() 185 if (auto vectorType = dyn_cast<VectorType>()) { in getSizeInBytes() local 187 vectorType.getElementType().cast<ScalarType>().getSizeInBytes(); in getSizeInBytes() 190 return *elementSize * vectorType.getNumElements(); in getSizeInBytes() 657 if (auto vectorType = type.dyn_cast<VectorType>()) in classof() local 658 return CompositeType::isValid(vectorType); in classof() 1130 if (auto vectorType = columnType.dyn_cast<VectorType>()) { in isValidColumnType() local [all …]
|
| /llvm-project-15.0.7/mlir/lib/Dialect/Quant/IR/ |
| H A D | QuantOps.cpp | 62 if (auto vectorType = expressed.dyn_cast<VectorType>()) in isValidQuantizationSpec() local 63 return spec == vectorType.getElementType(); in isValidQuantizationSpec()
|
| /llvm-project-15.0.7/mlir/lib/Dialect/SPIRV/Transforms/ |
| H A D | UnifyAliasedResourcePass.cpp | 87 if (auto vectorType = type.dyn_cast<VectorType>()) { in deduceCanonicalResource() local 88 if (vectorType.getNumElements() % 2 != 0) in deduceCanonicalResource() 96 vectorType.getElementType().getIntOrFloatBitWidth()); in deduceCanonicalResource() 466 auto vectorType = VectorType::get({ratio}, dstElemType); in matchAndRewrite() local 468 loc, vectorType, components); in matchAndRewrite()
|
| /llvm-project-15.0.7/mlir/include/mlir/Dialect/SPIRV/Utils/ |
| H A D | LayoutUtils.h | 67 static Type decorateType(VectorType vectorType, Size &size, Size &alignment);
|
| /llvm-project-15.0.7/mlir/include/mlir/Conversion/LLVMCommon/ |
| H A D | VectorPattern.h | 34 NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
|
| /llvm-project-15.0.7/mlir/include/mlir/Dialect/Vector/IR/ |
| H A D | VectorOps.h | 165 VectorType vectorType);
|