//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines the operations in the SPIR-V dialect. // //===----------------------------------------------------------------------===// #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" #include "mlir/Dialect/SPIRV/IR/ParserUtils.h" #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/CallInterfaces.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/bit.h" using namespace mlir; // TODO: generate these strings using ODS. 
static constexpr const char kMemoryAccessAttrName[] = "memory_access"; static constexpr const char kSourceMemoryAccessAttrName[] = "source_memory_access"; static constexpr const char kAlignmentAttrName[] = "alignment"; static constexpr const char kSourceAlignmentAttrName[] = "source_alignment"; static constexpr const char kBranchWeightAttrName[] = "branch_weights"; static constexpr const char kCallee[] = "callee"; static constexpr const char kClusterSize[] = "cluster_size"; static constexpr const char kControl[] = "control"; static constexpr const char kDefaultValueAttrName[] = "default_value"; static constexpr const char kExecutionScopeAttrName[] = "execution_scope"; static constexpr const char kEqualSemanticsAttrName[] = "equal_semantics"; static constexpr const char kFnNameAttrName[] = "fn"; static constexpr const char kGroupOperationAttrName[] = "group_operation"; static constexpr const char kIndicesAttrName[] = "indices"; static constexpr const char kInitializerAttrName[] = "initializer"; static constexpr const char kInterfaceAttrName[] = "interface"; static constexpr const char kMemoryScopeAttrName[] = "memory_scope"; static constexpr const char kSemanticsAttrName[] = "semantics"; static constexpr const char kSpecIdAttrName[] = "spec_id"; static constexpr const char kTypeAttrName[] = "type"; static constexpr const char kUnequalSemanticsAttrName[] = "unequal_semantics"; static constexpr const char kValueAttrName[] = "value"; static constexpr const char kValuesAttrName[] = "values"; static constexpr const char kCompositeSpecConstituentsName[] = "constituents"; //===----------------------------------------------------------------------===// // Common utility functions //===----------------------------------------------------------------------===// static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, OperationState &result) { SmallVector ops; Type type; // If the operand list is in-between parentheses, then we have a generic form. 
// (see the fallback in `printOneResultOp`). SMLoc loc = parser.getCurrentLocation(); if (!parser.parseOptionalLParen()) { if (parser.parseOperandList(ops) || parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || parser.parseType(type)) return failure(); auto fnType = type.dyn_cast(); if (!fnType) { parser.emitError(loc, "expected function type"); return failure(); } if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) return failure(); result.addTypes(fnType.getResults()); return success(); } return failure(parser.parseOperandList(ops) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperands(ops, type, result.operands) || parser.addTypeToList(type, result.types)); } static void printOneResultOp(Operation *op, OpAsmPrinter &p) { assert(op->getNumResults() == 1 && "op should have one result"); // If not all the operand and result types are the same, just use the // generic assembly form to avoid omitting information in printing. auto resultType = op->getResult(0).getType(); if (llvm::any_of(op->getOperandTypes(), [&](Type type) { return type != resultType; })) { p.printGenericOp(op, /*printOpName=*/false); return; } p << ' '; p.printOperands(op->getOperands()); p.printOptionalAttrDict(op->getAttrs()); // Now we can output only one type for all operands and the result. p << " : " << resultType; } /// Returns true if the given op is a function-like op or nested in a /// function-like op without a module-like op in the middle. static bool isNestedInFunctionOpInterface(Operation *op) { if (!op) return false; if (op->hasTrait()) return false; if (isa(op)) return true; return isNestedInFunctionOpInterface(op->getParentOp()); } /// Returns true if the given op is an module-like op that maintains a symbol /// table. 
static bool isDirectInModuleLikeOp(Operation *op) { return op && op->hasTrait(); } static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) { auto constOp = dyn_cast_or_null(op); if (!constOp) { return failure(); } auto valueAttr = constOp.value(); auto integerValueAttr = valueAttr.dyn_cast(); if (!integerValueAttr) { return failure(); } if (integerValueAttr.getType().isSignlessInteger()) value = integerValueAttr.getInt(); else value = integerValueAttr.getSInt(); return success(); } template static ArrayAttr getStrArrayAttrForEnumList(Builder &builder, ArrayRef enumValues, function_ref stringifyFn) { if (enumValues.empty()) { return nullptr; } SmallVector enumValStrs; enumValStrs.reserve(enumValues.size()); for (auto val : enumValues) { enumValStrs.emplace_back(stringifyFn(val)); } return builder.getStrArrayAttr(enumValStrs); } /// Parses the next string attribute in `parser` as an enumerant of the given /// `EnumClass`. template static ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, StringRef attrName = spirv::attributeName()) { Attribute attrVal; NamedAttrList attr; auto loc = parser.getCurrentLocation(); if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), attrName, attr)) { return failure(); } if (!attrVal.isa()) { return parser.emitError(loc, "expected ") << attrName << " attribute specified as string"; } auto attrOptional = spirv::symbolizeEnum(attrVal.cast().getValue()); if (!attrOptional) { return parser.emitError(loc, "invalid ") << attrName << " attribute specification: " << attrVal; } value = *attrOptional; return success(); } /// Parses the next string attribute in `parser` as an enumerant of the given /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer /// attribute with the enum class's name as attribute name. 
template static ParseResult parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, StringRef attrName = spirv::attributeName()) { if (parseEnumStrAttr(value, parser)) { return failure(); } state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr( llvm::bit_cast(value))); return success(); } /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass` /// and inserts the enumerant into `state` as an 32-bit integer attribute with /// the enum class's name as attribute name. template static ParseResult parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, StringRef attrName = spirv::attributeName()) { if (parseEnumKeywordAttr(value, parser)) { return failure(); } state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr( llvm::bit_cast(value))); return success(); } /// Parses Function, Selection and Loop control attributes. If no control is /// specified, "None" is used as a default. template static ParseResult parseControlAttribute(OpAsmParser &parser, OperationState &state, StringRef attrName = spirv::attributeName()) { if (succeeded(parser.parseOptionalKeyword(kControl))) { EnumClass control; if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) || parser.parseRParen()) return failure(); return success(); } // Set control to "None" otherwise. Builder builder = parser.getBuilder(); state.addAttribute(attrName, builder.getI32IntegerAttr(0)); return success(); } /// Parses optional memory access attributes attached to a memory access /// operand/pointer. Specifically, parses the following syntax: /// (`[` memory-access `]`)? 
/// where: /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` /// integer-literal | `"NonTemporal"` static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, OperationState &state) { // Parse an optional list of attributes staring with '[' if (parser.parseOptionalLSquare()) { // Nothing to do return success(); } spirv::MemoryAccess memoryAccessAttr; if (parseEnumStrAttr(memoryAccessAttr, parser, state, kMemoryAccessAttrName)) { return failure(); } if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) { // Parse integer attribute for alignment. Attribute alignmentAttr; Type i32Type = parser.getBuilder().getIntegerType(32); if (parser.parseComma() || parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName, state.attributes)) { return failure(); } } return parser.parseRSquare(); } // TODO Make sure to merge this and the previous function into one template // parameterized by memory access attribute name and alignment. Doing so now // results in VS2017 in producing an internal error (at the call site) that's // not detailed enough to understand what is happening. static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, OperationState &state) { // Parse an optional list of attributes staring with '[' if (parser.parseOptionalLSquare()) { // Nothing to do return success(); } spirv::MemoryAccess memoryAccessAttr; if (parseEnumStrAttr(memoryAccessAttr, parser, state, kSourceMemoryAccessAttrName)) { return failure(); } if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) { // Parse integer attribute for alignment. 
Attribute alignmentAttr; Type i32Type = parser.getBuilder().getIntegerType(32); if (parser.parseComma() || parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName, state.attributes)) { return failure(); } } return parser.parseRSquare(); } template static void printMemoryAccessAttribute( MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl &elidedAttrs, Optional memoryAccessAtrrValue = None, Optional alignmentAttrValue = None) { // Print optional memory access attribute. if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue : memoryOp.memory_access())) { elidedAttrs.push_back(kMemoryAccessAttrName); printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { // Print integer alignment attribute. if (auto alignment = (alignmentAttrValue ? alignmentAttrValue : memoryOp.alignment())) { elidedAttrs.push_back(kAlignmentAttrName); printer << ", " << alignment; } } printer << "]"; } elidedAttrs.push_back(spirv::attributeName()); } // TODO Make sure to merge this and the previous function into one template // parameterized by memory access attribute name and alignment. Doing so now // results in VS2017 in producing an internal error (at the call site) that's // not detailed enough to understand what is happening. template static void printSourceMemoryAccessAttribute( MemoryOpTy memoryOp, OpAsmPrinter &printer, SmallVectorImpl &elidedAttrs, Optional memoryAccessAtrrValue = None, Optional alignmentAttrValue = None) { printer << ", "; // Print optional memory access attribute. if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue : memoryOp.memory_access())) { elidedAttrs.push_back(kSourceMemoryAccessAttrName); printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { // Print integer alignment attribute. if (auto alignment = (alignmentAttrValue ? 
alignmentAttrValue : memoryOp.alignment())) { elidedAttrs.push_back(kSourceAlignmentAttrName); printer << ", " << alignment; } } printer << "]"; } elidedAttrs.push_back(spirv::attributeName()); } static ParseResult parseImageOperands(OpAsmParser &parser, spirv::ImageOperandsAttr &attr) { // Expect image operands if (parser.parseOptionalLSquare()) return success(); spirv::ImageOperands imageOperands; if (parseEnumStrAttr(imageOperands, parser)) return failure(); attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands); return parser.parseRSquare(); } static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, spirv::ImageOperandsAttr attr) { if (attr) { auto strImageOperands = stringifyImageOperands(attr.getValue()); printer << "[\"" << strImageOperands << "\"]"; } } template static LogicalResult verifyImageOperands(Op imageOp, spirv::ImageOperandsAttr attr, Operation::operand_range operands) { if (!attr) { if (operands.empty()) return success(); return imageOp.emitError("the Image Operands should encode what operands " "follow, as per Image Operands"); } // TODO: Add the validation rules for the following Image Operands. spirv::ImageOperands noSupportOperands = spirv::ImageOperands::Bias | spirv::ImageOperands::Lod | spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset | spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets | spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod | spirv::ImageOperands::MakeTexelAvailable | spirv::ImageOperands::MakeTexelVisible | spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend; if (spirv::bitEnumContains(attr.getValue(), noSupportOperands)) llvm_unreachable("unimplemented operands of Image Operands"); return success(); } static LogicalResult verifyCastOp(Operation *op, bool requireSameBitWidth = true, bool skipBitWidthCheck = false) { // Some CastOps have no limit on bit widths for result and operand type. 
if (skipBitWidthCheck) return success(); Type operandType = op->getOperand(0).getType(); Type resultType = op->getResult(0).getType(); // ODS checks that result type and operand type have the same shape. if (auto vectorType = operandType.dyn_cast()) { operandType = vectorType.getElementType(); resultType = resultType.cast().getElementType(); } if (auto coopMatrixType = operandType.dyn_cast()) { operandType = coopMatrixType.getElementType(); resultType = resultType.cast().getElementType(); } auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth(); auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth(); auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; if (requireSameBitWidth) { if (!isSameBitWidth) { return op->emitOpError( "expected the same bit widths for operand type and result " "type, but provided ") << operandType << " and " << resultType; } return success(); } if (isSameBitWidth) { return op->emitOpError( "expected the different bit widths for operand type and result " "type, but provided ") << operandType << " and " << resultType; } return success(); } template static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { // ODS checks for attributes values. Just need to verify that if the // memory-access attribute is Aligned, then the alignment attribute must be // present. auto *op = memoryOp.getOperation(); auto memAccessAttr = op->getAttr(kMemoryAccessAttrName); if (!memAccessAttr) { // Alignment attribute shouldn't be present if memory access attribute is // not present. 
if (op->getAttr(kAlignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification without aligned memory access " "specification"); } return success(); } auto memAccessVal = memAccessAttr.template cast(); auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt()); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") << memAccessVal; } if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { if (!op->getAttr(kAlignmentAttrName)) { return memoryOp.emitOpError("missing alignment value"); } } else { if (op->getAttr(kAlignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification with non-aligned memory access " "specification"); } } return success(); } // TODO Make sure to merge this and the previous function into one template // parameterized by memory access attribute name and alignment. Doing so now // results in VS2017 in producing an internal error (at the call site) that's // not detailed enough to understand what is happening. template static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) { // ODS checks for attributes values. Just need to verify that if the // memory-access attribute is Aligned, then the alignment attribute must be // present. auto *op = memoryOp.getOperation(); auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName); if (!memAccessAttr) { // Alignment attribute shouldn't be present if memory access attribute is // not present. 
if (op->getAttr(kSourceAlignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification without aligned memory access " "specification"); } return success(); } auto memAccessVal = memAccessAttr.template cast(); auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt()); if (!memAccess) { return memoryOp.emitOpError("invalid memory access specifier: ") << memAccessVal; } if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) { if (!op->getAttr(kSourceAlignmentAttrName)) { return memoryOp.emitOpError("missing alignment value"); } } else { if (op->getAttr(kSourceAlignmentAttrName)) { return memoryOp.emitOpError( "invalid alignment specification with non-aligned memory access " "specification"); } } return success(); } static LogicalResult verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) { // According to the SPIR-V specification: // "Despite being a mask and allowing multiple bits to be combined, it is // invalid for more than one of these four bits to be set: Acquire, Release, // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and // Release semantics is done by setting the AcquireRelease bit, not by setting // two bits." auto atMostOneInSet = spirv::MemorySemantics::Acquire | spirv::MemorySemantics::Release | spirv::MemorySemantics::AcquireRelease | spirv::MemorySemantics::SequentiallyConsistent; auto bitCount = llvm::countPopulation( static_cast(memorySemantics & atMostOneInSet)); if (bitCount > 1) { return op->emitError( "expected at most one of these four memory constraints " "to be set: `Acquire`, `Release`," "`AcquireRelease` or `SequentiallyConsistent`"); } return success(); } template static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, Value val) { // ODS already checks ptr is spirv::PointerType. 
Just check that the pointee // type of the pointer and the type of the value are the same // // TODO: Check that the value type satisfies restrictions of // SPIR-V OpLoad/OpStore operations if (val.getType() != ptr.getType().cast().getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); } template static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, Value ptr, Value val) { auto valType = val.getType(); if (auto valVecTy = valType.dyn_cast()) valType = valVecTy.getElementType(); if (valType != ptr.getType().cast().getPointeeType()) { return op.emitOpError("mismatch in result type and pointer type"); } return success(); } static ParseResult parseVariableDecorations(OpAsmParser &parser, OperationState &state) { auto builtInName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::BuiltIn)); if (succeeded(parser.parseOptionalKeyword("bind"))) { Attribute set, binding; // Parse optional descriptor binding auto descriptorSetName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::DescriptorSet)); auto bindingName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::Binding)); Type i32Type = parser.getBuilder().getIntegerType(32); if (parser.parseLParen() || parser.parseAttribute(set, i32Type, descriptorSetName, state.attributes) || parser.parseComma() || parser.parseAttribute(binding, i32Type, bindingName, state.attributes) || parser.parseRParen()) { return failure(); } } else if (succeeded(parser.parseOptionalKeyword(builtInName))) { StringAttr builtIn; if (parser.parseLParen() || parser.parseAttribute(builtIn, builtInName, state.attributes) || parser.parseRParen()) { return failure(); } } // Parse other attributes if (parser.parseOptionalAttrDict(state.attributes)) return failure(); return success(); } static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, SmallVectorImpl &elidedAttrs) { // Print optional 
descriptor binding auto descriptorSetName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::DescriptorSet)); auto bindingName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::Binding)); auto descriptorSet = op->getAttrOfType(descriptorSetName); auto binding = op->getAttrOfType(bindingName); if (descriptorSet && binding) { elidedAttrs.push_back(descriptorSetName); elidedAttrs.push_back(bindingName); printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() << ")"; } // Print BuiltIn attribute if present auto builtInName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::BuiltIn)); if (auto builtin = op->getAttrOfType(builtInName)) { printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; elidedAttrs.push_back(builtInName); } printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } // Get bit width of types. static unsigned getBitWidth(Type type) { if (type.isa()) { // Just return 64 bits for pointer types for now. // TODO: Make sure not caller relies on the actual pointer width value. return 64; } if (type.isIntOrFloat()) return type.getIntOrFloatBitWidth(); if (auto vectorType = type.dyn_cast()) { assert(vectorType.getElementType().isIntOrFloat()); return vectorType.getNumElements() * vectorType.getElementType().getIntOrFloatBitWidth(); } llvm_unreachable("unhandled bit width computation for type"); } /// Walks the given type hierarchy with the given indices, potentially down /// to component granularity, to select an element type. Returns null type and /// emits errors with the given loc on failure. 
static Type
getElementType(Type type, ArrayRef<int32_t> indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  if (indices.empty()) {
    emitErrorFn("expected at least one index for spv.CompositeExtract");
    return nullptr;
  }

  // Descend one composite level per index.
  for (int32_t index : indices) {
    auto compositeType = type.dyn_cast<spirv::CompositeType>();
    if (!compositeType) {
      emitErrorFn("cannot extract from non-composite type ")
          << type << " with index " << index;
      return nullptr;
    }

    // Bounds can only be checked when the element count is statically known.
    bool outOfBounds =
        compositeType.hasCompileTimeKnownNumElements() &&
        (index < 0 ||
         static_cast<uint64_t>(index) >= compositeType.getNumElements());
    if (outOfBounds) {
      emitErrorFn("index ") << index << " out of bounds for " << type;
      return nullptr;
    }

    type = compositeType.getElementType(index);
  }
  return type;
}

/// Overload taking the indices as an attribute, which must be a non-empty
/// array of integer attributes. Returns a null type after emitting an error
/// through `emitErrorFn` on failure.
static Type
getElementType(Type type, Attribute indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  auto indexArray = indices.dyn_cast<ArrayAttr>();
  if (!indexArray) {
    emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
    return nullptr;
  }
  if (indexArray.empty()) {
    emitErrorFn("expected at least one index for spv.CompositeExtract");
    return nullptr;
  }

  SmallVector<int32_t, 2> indexValues;
  for (Attribute element : indexArray) {
    auto intElement = element.dyn_cast<IntegerAttr>();
    if (!intElement) {
      emitErrorFn("expected an 32-bit integer for index, but found '")
          << element << "'";
      return nullptr;
    }
    indexValues.push_back(intElement.getInt());
  }
  return getElementType(type, indexValues, emitErrorFn);
}

/// Overload that reports failures as errors at `loc`.
static Type getElementType(Type type, Attribute indices, Location loc) {
  auto reportAtLoc = [&](StringRef message) -> InFlightDiagnostic {
    return ::mlir::emitError(loc, message);
  };
  return getElementType(type, indices, reportAtLoc);
}

/// Overload that reports failures through `parser` at source location `loc`.
static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
                           SMLoc loc) {
  auto reportViaParser = [&](StringRef message) -> InFlightDiagnostic {
    return parser.emitError(loc, message);
  };
  return getElementType(type, indices, reportViaParser);
}

/// Returns true if the given `block` only contains one `spv.mlir.merge` op.
static inline bool isMergeBlock(Block &block) { return !block.empty() && std::next(block.begin()) == block.end() && isa(block.front()); } //===----------------------------------------------------------------------===// // Common parsers and printers //===----------------------------------------------------------------------===// // Parses an atomic update op. If the update op does not take a value (like // AtomicIIncrement) `hasValue` must be false. static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, OperationState &state, bool hasValue) { spirv::Scope scope; spirv::MemorySemantics memoryScope; SmallVector operandInfo; OpAsmParser::UnresolvedOperand ptrInfo, valueInfo; Type type; SMLoc loc; if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) || parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) || parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) || parser.getCurrentLocation(&loc) || parser.parseColonType(type)) return failure(); auto ptrType = type.dyn_cast(); if (!ptrType) return parser.emitError(loc, "expected pointer type"); SmallVector operandTypes; operandTypes.push_back(ptrType); if (hasValue) operandTypes.push_back(ptrType.getPointeeType()); if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(), state.operands)) return failure(); return parser.addTypeToList(ptrType.getPointeeType(), state.types); } // Prints an atomic update op. 
static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) { printer << " \""; auto scopeAttr = op->getAttrOfType(kMemoryScopeAttrName); printer << spirv::stringifyScope( static_cast(scopeAttr.getInt())) << "\" \""; auto memorySemanticsAttr = op->getAttrOfType(kSemanticsAttrName); printer << spirv::stringifyMemorySemantics( static_cast( memorySemanticsAttr.getInt())) << "\" " << op->getOperands() << " : " << op->getOperand(0).getType(); } template static StringRef stringifyTypeName(); template <> StringRef stringifyTypeName() { return "integer"; } template <> StringRef stringifyTypeName() { return "float"; } // Verifies an atomic update op. template static LogicalResult verifyAtomicUpdateOp(Operation *op) { auto ptrType = op->getOperand(0).getType().cast(); auto elementType = ptrType.getPointeeType(); if (!elementType.isa()) return op->emitOpError() << "pointer operand must point to an " << stringifyTypeName() << " value, found " << elementType; if (op->getNumOperands() > 1) { auto valueType = op->getOperand(1).getType(); if (valueType != elementType) return op->emitOpError("expected value to have the same type as the " "pointer operand's pointee type ") << elementType << ", but found " << valueType; } auto memorySemantics = static_cast( op->getAttrOfType(kSemanticsAttrName).getInt()); if (failed(verifyMemorySemantics(op, memorySemantics))) { return failure(); } return success(); } static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, OperationState &state) { spirv::Scope executionScope; spirv::GroupOperation groupOperation; OpAsmParser::UnresolvedOperand valueInfo; if (parseEnumStrAttr(executionScope, parser, state, kExecutionScopeAttrName) || parseEnumStrAttr(groupOperation, parser, state, kGroupOperationAttrName) || parser.parseOperand(valueInfo)) return failure(); Optional clusterSizeInfo; if (succeeded(parser.parseOptionalKeyword(kClusterSize))) { clusterSizeInfo = OpAsmParser::UnresolvedOperand(); if (parser.parseLParen() || 
parser.parseOperand(*clusterSizeInfo) || parser.parseRParen()) return failure(); } Type resultType; if (parser.parseColonType(resultType)) return failure(); if (parser.resolveOperand(valueInfo, resultType, state.operands)) return failure(); if (clusterSizeInfo) { Type i32Type = parser.getBuilder().getIntegerType(32); if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands)) return failure(); } return parser.addTypeToList(resultType, state.types); } static void printGroupNonUniformArithmeticOp(Operation *groupOp, OpAsmPrinter &printer) { printer << " \"" << stringifyScope(static_cast( groupOp->getAttrOfType(kExecutionScopeAttrName) .getInt())) << "\" \"" << stringifyGroupOperation(static_cast( groupOp->getAttrOfType(kGroupOperationAttrName) .getInt())) << "\" " << groupOp->getOperand(0); if (groupOp->getNumOperands() > 1) printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')'; printer << " : " << groupOp->getResult(0).getType(); } static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) { spirv::Scope scope = static_cast( groupOp->getAttrOfType(kExecutionScopeAttrName).getInt()); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return groupOp->emitOpError( "execution scope must be 'Workgroup' or 'Subgroup'"); spirv::GroupOperation operation = static_cast( groupOp->getAttrOfType(kGroupOperationAttrName).getInt()); if (operation == spirv::GroupOperation::ClusteredReduce && groupOp->getNumOperands() == 1) return groupOp->emitOpError("cluster size operand must be provided for " "'ClusteredReduce' group operation"); if (groupOp->getNumOperands() > 1) { Operation *sizeOp = groupOp->getOperand(1).getDefiningOp(); int32_t clusterSize = 0; // TODO: support specialization constant here. 
if (failed(extractValueFromConstOp(sizeOp, clusterSize))) return groupOp->emitOpError( "cluster size operand must come from a constant op"); if (!llvm::isPowerOf2_32(clusterSize)) return groupOp->emitOpError( "cluster size operand must be a power of two"); } return success(); } /// Result of a logical op must be a scalar or vector of boolean type. static Type getUnaryOpResultType(Type operandType) { Builder builder(operandType.getContext()); Type resultType = builder.getIntegerType(1); if (auto vecType = operandType.dyn_cast()) return VectorType::get(vecType.getNumElements(), resultType); return resultType; } static LogicalResult verifyShiftOp(Operation *op) { if (op->getOperand(0).getType() != op->getResult(0).getType()) { return op->emitError("expected the same type for the first operand and " "result, but provided ") << op->getOperand(0).getType() << " and " << op->getResult(0).getType(); } return success(); } static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state, Value lhs, Value rhs) { assert(lhs.getType() == rhs.getType()); Type boolType = builder.getI1Type(); if (auto vecType = lhs.getType().dyn_cast()) boolType = VectorType::get(vecType.getShape(), boolType); state.addTypes(boolType); state.addOperands({lhs, rhs}); } static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state, Value value) { Type boolType = builder.getI1Type(); if (auto vecType = value.getType().dyn_cast()) boolType = VectorType::get(vecType.getShape(), boolType); state.addTypes(boolType); state.addOperands(value); } //===----------------------------------------------------------------------===// // spv.AccessChainOp //===----------------------------------------------------------------------===// static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { auto ptrType = type.dyn_cast(); if (!ptrType) { emitError(baseLoc, "'spv.AccessChain' op expected a pointer " "to composite type, but provided ") << type; return nullptr; } auto 
resultType = ptrType.getPointeeType(); auto resultStorageClass = ptrType.getStorageClass(); int32_t index = 0; for (auto indexSSA : indices) { auto cType = resultType.dyn_cast(); if (!cType) { emitError(baseLoc, "'spv.AccessChain' op cannot extract from non-composite type ") << resultType << " with index " << index; return nullptr; } index = 0; if (resultType.isa()) { Operation *op = indexSSA.getDefiningOp(); if (!op) { emitError(baseLoc, "'spv.AccessChain' op index must be an " "integer spv.Constant to access " "element of spv.struct"); return nullptr; } // TODO: this should be relaxed to allow // integer literals of other bitwidths. if (failed(extractValueFromConstOp(op, index))) { emitError(baseLoc, "'spv.AccessChain' index must be an integer spv.Constant to " "access element of spv.struct, but provided ") << op->getName(); return nullptr; } if (index < 0 || static_cast(index) >= cType.getNumElements()) { emitError(baseLoc, "'spv.AccessChain' op index ") << index << " out of bounds for " << resultType; return nullptr; } } resultType = cType.getElementType(index); } return spirv::PointerType::get(resultType, resultStorageClass); } void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state, Value basePtr, ValueRange indices) { auto type = getElementPtrType(basePtr.getType(), indices, state.location); assert(type && "Unable to deduce return type based on basePtr and indices"); build(builder, state, type, basePtr, indices); } ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser, OperationState &state) { OpAsmParser::UnresolvedOperand ptrInfo; SmallVector indicesInfo; Type type; auto loc = parser.getCurrentLocation(); SmallVector indicesTypes; if (parser.parseOperand(ptrInfo) || parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || parser.parseColonType(type) || parser.resolveOperand(ptrInfo, type, state.operands)) { return failure(); } // Check that the provided indices list is not empty before parsing their // type list. 
if (indicesInfo.empty()) { return mlir::emitError(state.location, "'spv.AccessChain' op expected at " "least one index "); } if (parser.parseComma() || parser.parseTypeList(indicesTypes)) return failure(); // Check that the indices types list is not empty and that it has a one-to-one // mapping to the provided indices. if (indicesTypes.size() != indicesInfo.size()) { return mlir::emitError(state.location, "'spv.AccessChain' op indices types' count must be " "equal to indices info count"); } if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) return failure(); auto resultType = getElementPtrType( type, llvm::makeArrayRef(state.operands).drop_front(), state.location); if (!resultType) { return failure(); } state.addTypes(resultType); return success(); } template static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { printer << ' ' << op.base_ptr() << '[' << indices << "] : " << op.base_ptr().getType() << ", " << indices.getTypes(); } void spirv::AccessChainOp::print(OpAsmPrinter &printer) { printAccessChain(*this, indices(), printer); } template static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(), indices, accessChainOp.getLoc()); if (!resultType) return failure(); auto providedResultType = accessChainOp.getType().template dyn_cast(); if (!providedResultType) return accessChainOp.emitOpError( "result type must be a pointer, but provided") << providedResultType; if (resultType != providedResultType) return accessChainOp.emitOpError("invalid result type: expected ") << resultType << ", but provided " << providedResultType; return success(); } LogicalResult spirv::AccessChainOp::verify() { return verifyAccessChain(*this, indices()); } //===----------------------------------------------------------------------===// // spv.mlir.addressof //===----------------------------------------------------------------------===// void 
spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state, spirv::GlobalVariableOp var) { build(builder, state, var.type(), SymbolRefAttr::get(var)); } LogicalResult spirv::AddressOfOp::verify() { auto varOp = dyn_cast_or_null( SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), variableAttr())); if (!varOp) { return emitOpError("expected spv.GlobalVariable symbol"); } if (pointer().getType() != varOp.type()) { return emitOpError( "result type mismatch with the referenced global variable's type"); } return success(); } template static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) { printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \"" << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \"" << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" " << atomOp.getOperands() << " : " << atomOp.pointer().getType(); } static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, OperationState &state) { spirv::Scope memoryScope; spirv::MemorySemantics equalSemantics, unequalSemantics; SmallVector operandInfo; Type type; if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) || parseEnumStrAttr(equalSemantics, parser, state, kEqualSemanticsAttrName) || parseEnumStrAttr(unequalSemantics, parser, state, kUnequalSemanticsAttrName) || parser.parseOperandList(operandInfo, 3)) return failure(); auto loc = parser.getCurrentLocation(); if (parser.parseColonType(type)) return failure(); auto ptrType = type.dyn_cast(); if (!ptrType) return parser.emitError(loc, "expected pointer type"); if (parser.resolveOperands( operandInfo, {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()}, parser.getNameLoc(), state.operands)) return failure(); return parser.addTypeToList(ptrType.getPointeeType(), state.types); } template static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) { // According to the spec: // "The type of Value must be the same as Result Type. 
The type of the value // pointed to by Pointer must be the same as Result Type. This type must also // match the type of Comparator." if (atomOp.getType() != atomOp.value().getType()) return atomOp.emitOpError("value operand must have the same type as the op " "result, but found ") << atomOp.value().getType() << " vs " << atomOp.getType(); if (atomOp.getType() != atomOp.comparator().getType()) return atomOp.emitOpError( "comparator operand must have the same type as the op " "result, but found ") << atomOp.comparator().getType() << " vs " << atomOp.getType(); Type pointeeType = atomOp.pointer() .getType() .template cast() .getPointeeType(); if (atomOp.getType() != pointeeType) return atomOp.emitOpError( "pointer operand's pointee type must have the same " "as the op result type, but found ") << pointeeType << " vs " << atomOp.getType(); // TODO: Unequal cannot be set to Release or Acquire and Release. // In addition, Unequal cannot be set to a stronger memory-order then Equal. return success(); } //===----------------------------------------------------------------------===// // spv.AtomicAndOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicAndOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } void spirv::AtomicAndOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicCompareExchangeOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicCompareExchangeOp::verify() { return ::verifyAtomicCompareExchangeImpl(*this); } ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicCompareExchangeImpl(parser, result); } void 
spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) { ::printAtomicCompareExchangeImpl(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicCompareExchangeWeakOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() { return ::verifyAtomicCompareExchangeImpl(*this); } ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicCompareExchangeImpl(parser, result); } void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) { ::printAtomicCompareExchangeImpl(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicExchange //===----------------------------------------------------------------------===// void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) { printer << " \"" << stringifyScope(memory_scope()) << "\" \"" << stringifyMemorySemantics(semantics()) << "\" " << getOperands() << " : " << pointer().getType(); } ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser, OperationState &state) { spirv::Scope memoryScope; spirv::MemorySemantics semantics; SmallVector operandInfo; Type type; if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) || parseEnumStrAttr(semantics, parser, state, kSemanticsAttrName) || parser.parseOperandList(operandInfo, 2)) return failure(); auto loc = parser.getCurrentLocation(); if (parser.parseColonType(type)) return failure(); auto ptrType = type.dyn_cast(); if (!ptrType) return parser.emitError(loc, "expected pointer type"); if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()}, parser.getNameLoc(), state.operands)) return failure(); return parser.addTypeToList(ptrType.getPointeeType(), state.types); } LogicalResult spirv::AtomicExchangeOp::verify() { if (getType() != value().getType()) return emitOpError("value operand must 
have the same type as the op " "result, but found ") << value().getType() << " vs " << getType(); Type pointeeType = pointer().getType().cast().getPointeeType(); if (getType() != pointeeType) return emitOpError("pointer operand's pointee type must have the same " "as the op result type, but found ") << pointeeType << " vs " << getType(); return success(); } //===----------------------------------------------------------------------===// // spv.AtomicIAddOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicIAddOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } void spirv::AtomicIAddOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicFAddEXTOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicFAddEXTOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicIDecrementOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicIDecrementOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, false); } void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // 
spv.AtomicIIncrementOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicIIncrementOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, false); } void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicISubOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicISubOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } void spirv::AtomicISubOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicOrOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicOrOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } void spirv::AtomicOrOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicSMaxOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicSMaxOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } 
//===----------------------------------------------------------------------===// // spv.AtomicSMinOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicSMinOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } void spirv::AtomicSMinOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicUMaxOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicUMaxOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicUMinOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicUMinOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } void spirv::AtomicUMinOp::print(OpAsmPrinter &p) { ::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // spv.AtomicXorOp //===----------------------------------------------------------------------===// LogicalResult spirv::AtomicXorOp::verify() { return ::verifyAtomicUpdateOp(getOperation()); } ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser, OperationState &result) { return ::parseAtomicUpdateOp(parser, result, true); } void spirv::AtomicXorOp::print(OpAsmPrinter &p) { 
::printAtomicUpdateOp(*this, p); } //===----------------------------------------------------------------------===// // spv.BitcastOp //===----------------------------------------------------------------------===// LogicalResult spirv::BitcastOp::verify() { // TODO: The SPIR-V spec validation rules are different for different // versions. auto operandType = operand().getType(); auto resultType = result().getType(); if (operandType == resultType) { return emitError("result type must be different from operand type"); } if (operandType.isa() && !resultType.isa()) { return emitError( "unhandled bit cast conversion from pointer type to non-pointer type"); } if (!operandType.isa() && resultType.isa()) { return emitError( "unhandled bit cast conversion from non-pointer type to pointer type"); } auto operandBitWidth = getBitWidth(operandType); auto resultBitWidth = getBitWidth(resultType); if (operandBitWidth != resultBitWidth) { return emitOpError("mismatch in result type bitwidth ") << resultBitWidth << " and operand type bitwidth " << operandBitWidth; } return success(); } //===----------------------------------------------------------------------===// // spv.BranchOp //===----------------------------------------------------------------------===// SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); return SuccessorOperands(0, targetOperandsMutable()); } //===----------------------------------------------------------------------===// // spv.BranchConditionalOp //===----------------------------------------------------------------------===// SuccessorOperands spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { assert(index < 2 && "invalid successor index"); return SuccessorOperands(index == kTrueIndex ? 
trueTargetOperandsMutable() : falseTargetOperandsMutable()); } ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser, OperationState &state) { auto &builder = parser.getBuilder(); OpAsmParser::UnresolvedOperand condInfo; Block *dest; // Parse the condition. Type boolTy = builder.getI1Type(); if (parser.parseOperand(condInfo) || parser.resolveOperand(condInfo, boolTy, state.operands)) return failure(); // Parse the optional branch weights. if (succeeded(parser.parseOptionalLSquare())) { IntegerAttr trueWeight, falseWeight; NamedAttrList weights; auto i32Type = builder.getIntegerType(32); if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) || parser.parseComma() || parser.parseAttribute(falseWeight, i32Type, "weight", weights) || parser.parseRSquare()) return failure(); state.addAttribute(kBranchWeightAttrName, builder.getArrayAttr({trueWeight, falseWeight})); } // Parse the true branch. SmallVector trueOperands; if (parser.parseComma() || parser.parseSuccessorAndUseList(dest, trueOperands)) return failure(); state.addSuccessors(dest); state.addOperands(trueOperands); // Parse the false branch. 
SmallVector falseOperands; if (parser.parseComma() || parser.parseSuccessorAndUseList(dest, falseOperands)) return failure(); state.addSuccessors(dest); state.addOperands(falseOperands); state.addAttribute( spirv::BranchConditionalOp::getOperandSegmentSizeAttr(), builder.getI32VectorAttr({1, static_cast(trueOperands.size()), static_cast(falseOperands.size())})); return success(); } void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) { printer << ' ' << condition(); if (auto weights = branch_weights()) { printer << " ["; llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { printer << a.cast().getInt(); }); printer << "]"; } printer << ", "; printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments()); printer << ", "; printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments()); } LogicalResult spirv::BranchConditionalOp::verify() { if (auto weights = branch_weights()) { if (weights->getValue().size() != 2) { return emitOpError("must have exactly two branch weights"); } if (llvm::all_of(*weights, [](Attribute attr) { return attr.cast().getValue().isNullValue(); })) return emitOpError("branch weights cannot both be zero"); } return success(); } //===----------------------------------------------------------------------===// // spv.CompositeConstruct //===----------------------------------------------------------------------===// ParseResult spirv::CompositeConstructOp::parse(OpAsmParser &parser, OperationState &state) { SmallVector operands; Type type; auto loc = parser.getCurrentLocation(); if (parser.parseOperandList(operands) || parser.parseColonType(type)) { return failure(); } auto cType = type.dyn_cast(); if (!cType) { return parser.emitError( loc, "result type must be a composite type, but provided ") << type; } if (cType.hasCompileTimeKnownNumElements() && operands.size() != cType.getNumElements()) { return parser.emitError(loc, "has incorrect number of operands: expected ") << cType.getNumElements() << 
", but provided " << operands.size(); } // TODO: Add support for constructing a vector type from the vector operands. // According to the spec: "for constructing a vector, the operands may // also be vectors with the same component type as the Result Type component // type". SmallVector elementTypes; elementTypes.reserve(operands.size()); for (auto index : llvm::seq(0, operands.size())) { elementTypes.push_back(cType.getElementType(index)); } state.addTypes(type); return parser.resolveOperands(operands, elementTypes, loc, state.operands); } void spirv::CompositeConstructOp::print(OpAsmPrinter &printer) { printer << " " << constituents() << " : " << getResult().getType(); } LogicalResult spirv::CompositeConstructOp::verify() { auto cType = getType().cast(); operand_range constituents = this->constituents(); if (cType.isa()) { if (constituents.size() != 1) return emitError("has incorrect number of operands: expected ") << "1, but provided " << constituents.size(); } else if (constituents.size() != cType.getNumElements()) { return emitError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " << constituents.size(); } for (auto index : llvm::seq(0, constituents.size())) { if (constituents[index].getType() != cType.getElementType(index)) { return emitError("operand type mismatch: expected operand type ") << cType.getElementType(index) << ", but provided " << constituents[index].getType(); } } return success(); } //===----------------------------------------------------------------------===// // spv.CompositeExtractOp //===----------------------------------------------------------------------===// void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state, Value composite, ArrayRef indices) { auto indexAttr = builder.getI32ArrayAttr(indices); auto elementType = getElementType(composite.getType(), indexAttr, state.location); if (!elementType) { return; } build(builder, state, elementType, composite, 
indexAttr); } ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser, OperationState &state) { OpAsmParser::UnresolvedOperand compositeInfo; Attribute indicesAttr; Type compositeType; SMLoc attrLocation; if (parser.parseOperand(compositeInfo) || parser.getCurrentLocation(&attrLocation) || parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) || parser.parseColonType(compositeType) || parser.resolveOperand(compositeInfo, compositeType, state.operands)) { return failure(); } Type resultType = getElementType(compositeType, indicesAttr, parser, attrLocation); if (!resultType) { return failure(); } state.addTypes(resultType); return success(); } void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) { printer << ' ' << composite() << indices() << " : " << composite().getType(); } LogicalResult spirv::CompositeExtractOp::verify() { auto indicesArrayAttr = indices().dyn_cast(); auto resultType = getElementType(composite().getType(), indicesArrayAttr, getLoc()); if (!resultType) return failure(); if (resultType != getType()) { return emitOpError("invalid result type: expected ") << resultType << " but provided " << getType(); } return success(); } //===----------------------------------------------------------------------===// // spv.CompositeInsert //===----------------------------------------------------------------------===// void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state, Value object, Value composite, ArrayRef indices) { auto indexAttr = builder.getI32ArrayAttr(indices); build(builder, state, composite.getType(), object, composite, indexAttr); } ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser, OperationState &state) { SmallVector operands; Type objectType, compositeType; Attribute indicesAttr; auto loc = parser.getCurrentLocation(); return failure( parser.parseOperandList(operands, 2) || parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) || 
parser.parseColonType(objectType) || parser.parseKeywordType("into", compositeType) || parser.resolveOperands(operands, {objectType, compositeType}, loc, state.operands) || parser.addTypesToList(compositeType, state.types)); } LogicalResult spirv::CompositeInsertOp::verify() { auto indicesArrayAttr = indices().dyn_cast(); auto objectType = getElementType(composite().getType(), indicesArrayAttr, getLoc()); if (!objectType) return failure(); if (objectType != object().getType()) { return emitOpError("object operand type should be ") << objectType << ", but found " << object().getType(); } if (composite().getType() != getType()) { return emitOpError("result type should be the same as " "the composite type, but found ") << composite().getType() << " vs " << getType(); } return success(); } void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) { printer << " " << object() << ", " << composite() << indices() << " : " << object().getType() << " into " << composite().getType(); } //===----------------------------------------------------------------------===// // spv.Constant //===----------------------------------------------------------------------===// ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, OperationState &state) { Attribute value; if (parser.parseAttribute(value, kValueAttrName, state.attributes)) return failure(); Type type = value.getType(); if (type.isa()) { if (parser.parseColonType(type)) return failure(); } return parser.addTypeToList(type, state.types); } void spirv::ConstantOp::print(OpAsmPrinter &printer) { printer << ' ' << value(); if (getType().isa()) printer << " : " << getType(); } static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, Type opType) { auto valueType = value.getType(); if (value.isa()) { if (valueType != opType) return op.emitOpError("result type (") << opType << ") does not match value type (" << valueType << ")"; return success(); } if (value.isa()) { if (valueType == opType) return 
success(); auto arrayType = opType.dyn_cast(); auto shapedType = valueType.dyn_cast(); if (!arrayType) return op.emitOpError("result or element type (") << opType << ") does not match value type (" << valueType << "), must be the same or spv.array"; int numElements = arrayType.getNumElements(); auto opElemType = arrayType.getElementType(); while (auto t = opElemType.dyn_cast()) { numElements *= t.getNumElements(); opElemType = t.getElementType(); } if (!opElemType.isIntOrFloat()) return op.emitOpError("only support nested array result type"); auto valueElemType = shapedType.getElementType(); if (valueElemType != opElemType) { return op.emitOpError("result element type (") << opElemType << ") does not match value element type (" << valueElemType << ")"; } if (numElements != shapedType.getNumElements()) { return op.emitOpError("result number of elements (") << numElements << ") does not match value number of elements (" << shapedType.getNumElements() << ")"; } return success(); } if (auto arrayAttr = value.dyn_cast()) { auto arrayType = opType.dyn_cast(); if (!arrayType) return op.emitOpError("must have spv.array result type for array value"); Type elemType = arrayType.getElementType(); for (Attribute element : arrayAttr.getValue()) { // Verify array elements recursively. if (failed(verifyConstantType(op, element, elemType))) return failure(); } return success(); } return op.emitOpError("cannot have value of type ") << valueType; } LogicalResult spirv::ConstantOp::verify() { // ODS already generates checks to make sure the result type is valid. We just // need to additionally check that the value's attribute type is consistent // with the result type. return verifyConstantType(*this, valueAttr(), getType()); } bool spirv::ConstantOp::isBuildableWith(Type type) { // Must be valid SPIR-V type first. 
if (!type.isa()) return false; if (isa(type.getDialect())) { // TODO: support constant struct return type.isa(); } return true; } spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, OpBuilder &builder) { if (auto intType = type.dyn_cast()) { unsigned width = intType.getWidth(); if (width == 1) return builder.create(loc, type, builder.getBoolAttr(false)); return builder.create( loc, type, builder.getIntegerAttr(type, APInt(width, 0))); } if (auto floatType = type.dyn_cast()) { return builder.create( loc, type, builder.getFloatAttr(floatType, 0.0)); } if (auto vectorType = type.dyn_cast()) { Type elemType = vectorType.getElementType(); if (elemType.isa()) { return builder.create( loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 0.0).getValue())); } if (elemType.isa()) { return builder.create( loc, type, DenseFPElementsAttr::get(vectorType, FloatAttr::get(elemType, 0.0).getValue())); } } llvm_unreachable("unimplemented types for ConstantOp::getZero()"); } spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, OpBuilder &builder) { if (auto intType = type.dyn_cast()) { unsigned width = intType.getWidth(); if (width == 1) return builder.create(loc, type, builder.getBoolAttr(true)); return builder.create( loc, type, builder.getIntegerAttr(type, APInt(width, 1))); } if (auto floatType = type.dyn_cast()) { return builder.create( loc, type, builder.getFloatAttr(floatType, 1.0)); } if (auto vectorType = type.dyn_cast()) { Type elemType = vectorType.getElementType(); if (elemType.isa()) { return builder.create( loc, type, DenseElementsAttr::get(vectorType, IntegerAttr::get(elemType, 1.0).getValue())); } if (elemType.isa()) { return builder.create( loc, type, DenseFPElementsAttr::get(vectorType, FloatAttr::get(elemType, 1.0).getValue())); } } llvm_unreachable("unimplemented types for ConstantOp::getOne()"); } void mlir::spirv::ConstantOp::getAsmResultNames( llvm::function_ref setNameFn) { Type type = getType(); 
SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << "cst"; IntegerType intTy = type.dyn_cast(); if (IntegerAttr intCst = value().dyn_cast()) { if (intTy && intTy.getWidth() == 1) { return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); } if (intTy.isSignless()) { specialName << intCst.getInt(); } else { specialName << intCst.getSInt(); } } if (intTy || type.isa()) { specialName << '_' << type; } if (auto vecType = type.dyn_cast()) { specialName << "_vec_"; specialName << vecType.getDimSize(0); Type elementType = vecType.getElementType(); if (elementType.isa() || elementType.isa()) { specialName << "x" << elementType; } } setNameFn(getResult(), specialName.str()); } void mlir::spirv::AddressOfOp::getAsmResultNames( llvm::function_ref setNameFn) { SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << variable() << "_addr"; setNameFn(getResult(), specialName.str()); } //===----------------------------------------------------------------------===// // spv.ControlBarrierOp //===----------------------------------------------------------------------===// LogicalResult spirv::ControlBarrierOp::verify() { return verifyMemorySemantics(getOperation(), memory_semantics()); } //===----------------------------------------------------------------------===// // spv.ConvertFToSOp //===----------------------------------------------------------------------===// LogicalResult spirv::ConvertFToSOp::verify() { return verifyCastOp(*this, /*requireSameBitWidth=*/false, /*skipBitWidthCheck=*/true); } //===----------------------------------------------------------------------===// // spv.ConvertFToUOp //===----------------------------------------------------------------------===// LogicalResult spirv::ConvertFToUOp::verify() { return verifyCastOp(*this, /*requireSameBitWidth=*/false, /*skipBitWidthCheck=*/true); } 
//===----------------------------------------------------------------------===// // spv.ConvertSToFOp //===----------------------------------------------------------------------===// LogicalResult spirv::ConvertSToFOp::verify() { return verifyCastOp(*this, /*requireSameBitWidth=*/false, /*skipBitWidthCheck=*/true); } //===----------------------------------------------------------------------===// // spv.ConvertUToFOp //===----------------------------------------------------------------------===// LogicalResult spirv::ConvertUToFOp::verify() { return verifyCastOp(*this, /*requireSameBitWidth=*/false, /*skipBitWidthCheck=*/true); } //===----------------------------------------------------------------------===// // spv.EntryPoint //===----------------------------------------------------------------------===// void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state, spirv::ExecutionModel executionModel, spirv::FuncOp function, ArrayRef interfaceVars) { build(builder, state, spirv::ExecutionModelAttr::get(builder.getContext(), executionModel), SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars)); } ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser, OperationState &state) { spirv::ExecutionModel execModel; SmallVector identifiers; SmallVector idTypes; SmallVector interfaceVars; FlatSymbolRefAttr fn; if (parseEnumStrAttr(execModel, parser, state) || parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) { return failure(); } if (!parser.parseOptionalComma()) { // Parse the interface variables if (parser.parseCommaSeparatedList([&]() -> ParseResult { // The name of the interface variable attribute isnt important FlatSymbolRefAttr var; NamedAttrList attrs; if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) return failure(); interfaceVars.push_back(var); return success(); })) return failure(); } state.addAttribute(kInterfaceAttrName, parser.getBuilder().getArrayAttr(interfaceVars)); return success(); } void 
spirv::EntryPointOp::print(OpAsmPrinter &printer) { printer << " \"" << stringifyExecutionModel(execution_model()) << "\" "; printer.printSymbolName(fn()); auto interfaceVars = interface().getValue(); if (!interfaceVars.empty()) { printer << ", "; llvm::interleaveComma(interfaceVars, printer); } } LogicalResult spirv::EntryPointOp::verify() { // Checks for fn and interface symbol reference are done in spirv::ModuleOp // verification. return success(); } //===----------------------------------------------------------------------===// // spv.ExecutionMode //===----------------------------------------------------------------------===// void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state, spirv::FuncOp function, spirv::ExecutionMode executionMode, ArrayRef params) { build(builder, state, SymbolRefAttr::get(function), spirv::ExecutionModeAttr::get(builder.getContext(), executionMode), builder.getI32ArrayAttr(params)); } ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser, OperationState &state) { spirv::ExecutionMode execMode; Attribute fn; if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) || parseEnumStrAttr(execMode, parser, state)) { return failure(); } SmallVector values; Type i32Type = parser.getBuilder().getIntegerType(32); while (!parser.parseOptionalComma()) { NamedAttrList attr; Attribute value; if (parser.parseAttribute(value, i32Type, "value", attr)) { return failure(); } values.push_back(value.cast().getInt()); } state.addAttribute(kValuesAttrName, parser.getBuilder().getI32ArrayAttr(values)); return success(); } void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) { printer << " "; printer.printSymbolName(fn()); printer << " \"" << stringifyExecutionMode(execution_mode()) << "\""; auto values = this->values(); if (values.empty()) return; printer << ", "; llvm::interleaveComma(values, printer, [&](Attribute a) { printer << a.cast().getInt(); }); } 
//===----------------------------------------------------------------------===// // spv.FConvertOp //===----------------------------------------------------------------------===// LogicalResult spirv::FConvertOp::verify() { return verifyCastOp(*this, /*requireSameBitWidth=*/false); } //===----------------------------------------------------------------------===// // spv.SConvertOp //===----------------------------------------------------------------------===// LogicalResult spirv::SConvertOp::verify() { return verifyCastOp(*this, /*requireSameBitWidth=*/false); } //===----------------------------------------------------------------------===// // spv.UConvertOp //===----------------------------------------------------------------------===// LogicalResult spirv::UConvertOp::verify() { return verifyCastOp(*this, /*requireSameBitWidth=*/false); } //===----------------------------------------------------------------------===// // spv.func //===----------------------------------------------------------------------===// ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) { SmallVector entryArgs; SmallVector resultAttrs; SmallVector resultTypes; auto &builder = parser.getBuilder(); // Parse the name as a symbol. StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), state.attributes)) return failure(); // Parse the function signature. bool isVariadic = false; if (function_interface_impl::parseFunctionSignature( parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, resultAttrs)) return failure(); SmallVector argTypes; for (auto &arg : entryArgs) argTypes.push_back(arg.type); auto fnType = builder.getFunctionType(argTypes, resultTypes); state.addAttribute(FunctionOpInterface::getTypeAttrName(), TypeAttr::get(fnType)); // Parse the optional function control keyword. 
spirv::FunctionControl fnControl; if (parseEnumStrAttr(fnControl, parser, state)) return failure(); // If additional attributes are present, parse them. if (parser.parseOptionalAttrDictWithKeyword(state.attributes)) return failure(); // Add the attributes to the function arguments. assert(resultAttrs.size() == resultTypes.size()); function_interface_impl::addArgAndResultAttrs(builder, state, entryArgs, resultAttrs); // Parse the optional function body. auto *body = state.addRegion(); OptionalParseResult result = parser.parseOptionalRegion(*body, entryArgs); return failure(result.hasValue() && failed(*result)); } void spirv::FuncOp::print(OpAsmPrinter &printer) { // Print function name, signature, and control. printer << " "; printer.printSymbolName(sym_name()); auto fnType = getFunctionType(); function_interface_impl::printFunctionSignature( printer, *this, fnType.getInputs(), /*isVariadic=*/false, fnType.getResults()); printer << " \"" << spirv::stringifyFunctionControl(function_control()) << "\""; function_interface_impl::printFunctionAttributes( printer, *this, fnType.getNumInputs(), fnType.getNumResults(), {spirv::attributeName()}); // Print the body if this is not an external function. 
Region &body = this->body(); if (!body.empty()) { printer << ' '; printer.printRegion(body, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } } LogicalResult spirv::FuncOp::verifyType() { auto type = getFunctionTypeAttr().getValue(); if (!type.isa()) return emitOpError("requires '" + getTypeAttrName() + "' attribute of function type"); if (getFunctionType().getNumResults() > 1) return emitOpError("cannot have more than one result"); return success(); } LogicalResult spirv::FuncOp::verifyBody() { FunctionType fnType = getFunctionType(); auto walkResult = walk([fnType](Operation *op) -> WalkResult { if (auto retOp = dyn_cast(op)) { if (fnType.getNumResults() != 0) return retOp.emitOpError("cannot be used in functions returning value"); } else if (auto retOp = dyn_cast(op)) { if (fnType.getNumResults() != 1) return retOp.emitOpError( "returns 1 value but enclosing function requires ") << fnType.getNumResults() << " results"; auto retOperandType = retOp.value().getType(); auto fnResultType = fnType.getResult(0); if (retOperandType != fnResultType) return retOp.emitOpError(" return value's type (") << retOperandType << ") mismatch with function's result type (" << fnResultType << ")"; } return WalkResult::advance(); }); // TODO: verify other bits like linkage type. return failure(walkResult.wasInterrupted()); } void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name, FunctionType type, spirv::FunctionControl control, ArrayRef attrs) { state.addAttribute(SymbolTable::getSymbolAttrName(), builder.getStringAttr(name)); state.addAttribute(getTypeAttrName(), TypeAttr::get(type)); state.addAttribute(spirv::attributeName(), builder.getI32IntegerAttr(static_cast(control))); state.attributes.append(attrs.begin(), attrs.end()); state.addRegion(); } // CallableOpInterface Region *spirv::FuncOp::getCallableRegion() { return isExternal() ? 
nullptr : &body(); } // CallableOpInterface ArrayRef spirv::FuncOp::getCallableResults() { return getFunctionType().getResults(); } //===----------------------------------------------------------------------===// // spv.FunctionCall //===----------------------------------------------------------------------===// LogicalResult spirv::FunctionCallOp::verify() { auto fnName = calleeAttr(); auto funcOp = dyn_cast_or_null( SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName)); if (!funcOp) { return emitOpError("callee function '") << fnName.getValue() << "' not found in nearest symbol table"; } auto functionType = funcOp.getFunctionType(); if (getNumResults() > 1) { return emitOpError( "expected callee function to have 0 or 1 result, but provided ") << getNumResults(); } if (functionType.getNumInputs() != getNumOperands()) { return emitOpError("has incorrect number of operands for callee: expected ") << functionType.getNumInputs() << ", but provided " << getNumOperands(); } for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { if (getOperand(i).getType() != functionType.getInput(i)) { return emitOpError("operand type mismatch: expected operand type ") << functionType.getInput(i) << ", but provided " << getOperand(i).getType() << " for operand number " << i; } } if (functionType.getNumResults() != getNumResults()) { return emitOpError( "has incorrect number of results has for callee: expected ") << functionType.getNumResults() << ", but provided " << getNumResults(); } if (getNumResults() && (getResult(0).getType() != functionType.getResult(0))) { return emitOpError("result type mismatch: expected ") << functionType.getResult(0) << ", but provided " << getResult(0).getType(); } return success(); } CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() { return (*this)->getAttrOfType(kCallee); } Operation::operand_range spirv::FunctionCallOp::getArgOperands() { return arguments(); } 
//===----------------------------------------------------------------------===// // spv.GLFClampOp //===----------------------------------------------------------------------===// ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); } void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } //===----------------------------------------------------------------------===// // spv.GLUClampOp //===----------------------------------------------------------------------===// ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); } void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } //===----------------------------------------------------------------------===// // spv.GLSClampOp //===----------------------------------------------------------------------===// ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); } void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } //===----------------------------------------------------------------------===// // spv.GLFmaOp //===----------------------------------------------------------------------===// ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) { return parseOneResultSameOperandTypeOp(parser, result); } void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } //===----------------------------------------------------------------------===// // spv.GlobalVariable //===----------------------------------------------------------------------===// void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, Type type, StringRef name, unsigned descriptorSet, unsigned binding) { build(builder, state, TypeAttr::get(type), 
builder.getStringAttr(name)); state.addAttribute( spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet), builder.getI32IntegerAttr(descriptorSet)); state.addAttribute( spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding), builder.getI32IntegerAttr(binding)); } void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, Type type, StringRef name, spirv::BuiltIn builtin) { build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); state.addAttribute( spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn), builder.getStringAttr(spirv::stringifyBuiltIn(builtin))); } ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser, OperationState &state) { // Parse variable name. StringAttr nameAttr; if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), state.attributes)) { return failure(); } // Parse optional initializer if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) { FlatSymbolRefAttr initSymbol; if (parser.parseLParen() || parser.parseAttribute(initSymbol, Type(), kInitializerAttrName, state.attributes) || parser.parseRParen()) return failure(); } if (parseVariableDecorations(parser, state)) { return failure(); } Type type; auto loc = parser.getCurrentLocation(); if (parser.parseColonType(type)) { return failure(); } if (!type.isa()) { return parser.emitError(loc, "expected spv.ptr type"); } state.addAttribute(kTypeAttrName, TypeAttr::get(type)); return success(); } void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs{ spirv::attributeName()}; // Print variable name. 
printer << ' '; printer.printSymbolName(sym_name()); elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); // Print optional initializer if (auto initializer = this->initializer()) { printer << " " << kInitializerAttrName << '('; printer.printSymbolName(*initializer); printer << ')'; elidedAttrs.push_back(kInitializerAttrName); } elidedAttrs.push_back(kTypeAttrName); printVariableDecorations(*this, printer, elidedAttrs); printer << " : " << type(); } LogicalResult spirv::GlobalVariableOp::verify() { // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the // object. It cannot be Generic. It must be the same as the Storage Class // operand of the Result Type." // Also, Function storage class is reserved by spv.Variable. auto storageClass = this->storageClass(); if (storageClass == spirv::StorageClass::Generic || storageClass == spirv::StorageClass::Function) { return emitOpError("storage class cannot be '") << stringifyStorageClass(storageClass) << "'"; } if (auto init = (*this)->getAttrOfType(kInitializerAttrName)) { Operation *initOp = SymbolTable::lookupNearestSymbolFrom( (*this)->getParentOp(), init.getAttr()); // TODO: Currently only variable initialization with specialization // constants and other variables is supported. They could be normal // constants in the module scope as well. 
if (!initOp || !isa(initOp)) { return emitOpError("initializer must be result of a " "spv.SpecConstant or spv.GlobalVariable op"); } } return success(); } //===----------------------------------------------------------------------===// // spv.GroupBroadcast //===----------------------------------------------------------------------===// LogicalResult spirv::GroupBroadcastOp::verify() { spirv::Scope scope = execution_scope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); if (auto localIdTy = localid().getType().dyn_cast()) if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) return emitOpError("localid is a vector and can be with only " " 2 or 3 components, actual number is ") << localIdTy.getNumElements(); return success(); } //===----------------------------------------------------------------------===// // spv.GroupNonUniformBallotOp //===----------------------------------------------------------------------===// LogicalResult spirv::GroupNonUniformBallotOp::verify() { spirv::Scope scope = execution_scope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); return success(); } //===----------------------------------------------------------------------===// // spv.GroupNonUniformBroadcast //===----------------------------------------------------------------------===// LogicalResult spirv::GroupNonUniformBroadcastOp::verify() { spirv::Scope scope = execution_scope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); // SPIR-V spec: "Before version 1.5, Id must come from a // constant instruction. 
auto targetEnv = spirv::getDefaultTargetEnv(getContext()); if (auto spirvModule = (*this)->getParentOfType()) targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule); if (targetEnv.getVersion() < spirv::Version::V_1_5) { auto *idOp = id().getDefiningOp(); if (!idOp || !isa(idOp)) // for spec constant return emitOpError("id must be the result of a constant op"); } return success(); } //===----------------------------------------------------------------------===// // spv.SubgroupBlockReadINTEL //===----------------------------------------------------------------------===// ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser, OperationState &state) { // Parse the storage class specification spirv::StorageClass storageClass; OpAsmParser::UnresolvedOperand ptrInfo; Type elementType; if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || parser.parseColon() || parser.parseType(elementType)) { return failure(); } auto ptrType = spirv::PointerType::get(elementType, storageClass); if (auto valVecTy = elementType.dyn_cast()) ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) { return failure(); } state.addTypes(elementType); return success(); } void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) { printer << " " << ptr() << " : " << getType(); } LogicalResult spirv::SubgroupBlockReadINTELOp::verify() { if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) return failure(); return success(); } //===----------------------------------------------------------------------===// // spv.SubgroupBlockWriteINTEL //===----------------------------------------------------------------------===// ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser, OperationState &state) { // Parse the storage class specification spirv::StorageClass storageClass; SmallVector operandInfo; auto loc = parser.getCurrentLocation(); 
Type elementType; if (parseEnumStrAttr(storageClass, parser) || parser.parseOperandList(operandInfo, 2) || parser.parseColon() || parser.parseType(elementType)) { return failure(); } auto ptrType = spirv::PointerType::get(elementType, storageClass); if (auto valVecTy = elementType.dyn_cast()) ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, state.operands)) { return failure(); } return success(); } void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) { printer << " " << ptr() << ", " << value() << " : " << value().getType(); } LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() { if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) return failure(); return success(); } //===----------------------------------------------------------------------===// // spv.GroupNonUniformElectOp //===----------------------------------------------------------------------===// LogicalResult spirv::GroupNonUniformElectOp::verify() { spirv::Scope scope = execution_scope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); return success(); } //===----------------------------------------------------------------------===// // spv.GroupNonUniformFAddOp //===----------------------------------------------------------------------===// LogicalResult spirv::GroupNonUniformFAddOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser, OperationState &result) { return parseGroupNonUniformArithmeticOp(parser, result); } void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) { printGroupNonUniformArithmeticOp(*this, p); } //===----------------------------------------------------------------------===// // spv.GroupNonUniformFMaxOp 
//===----------------------------------------------------------------------===//

// Each op in this run forwards verify/parse/print to the shared
// GroupNonUniformArithmeticOp helpers.

/// Verifies via the shared group non-uniform arithmetic verifier.
LogicalResult spirv::GroupNonUniformFMaxOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}

/// Parses via the shared group non-uniform arithmetic parser.
ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
/// Prints via the shared group non-uniform arithmetic printer.
void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformFMinOp
//===----------------------------------------------------------------------===//

/// Verifies via the shared group non-uniform arithmetic verifier.
LogicalResult spirv::GroupNonUniformFMinOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}

/// Parses via the shared group non-uniform arithmetic parser.
ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
/// Prints via the shared group non-uniform arithmetic printer.
void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformFMulOp
//===----------------------------------------------------------------------===//

/// Verifies via the shared group non-uniform arithmetic verifier.
LogicalResult spirv::GroupNonUniformFMulOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}

/// Parses via the shared group non-uniform arithmetic parser.
ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
/// Prints via the shared group non-uniform arithmetic printer.
void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformIAddOp
//===----------------------------------------------------------------------===//

/// Verifies via the shared group non-uniform arithmetic verifier.
LogicalResult spirv::GroupNonUniformIAddOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}

/// Parses via the shared group non-uniform arithmetic parser.
ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
/// Prints via the shared group non-uniform arithmetic printer.
void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformIMulOp
//===----------------------------------------------------------------------===//

/// Verifies via the shared group non-uniform arithmetic verifier.
LogicalResult spirv::GroupNonUniformIMulOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}

/// Parses via the shared group non-uniform arithmetic parser.
ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
/// Prints via the shared group non-uniform arithmetic printer.
void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformSMaxOp
//===----------------------------------------------------------------------===//

/// Verifies via the shared group non-uniform arithmetic verifier.
LogicalResult spirv::GroupNonUniformSMaxOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}

/// Parses via the shared group non-uniform arithmetic parser.
ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
/// Prints via the shared group non-uniform arithmetic printer.
void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformSMinOp
//===----------------------------------------------------------------------===//

/// Verifies via the shared group non-uniform arithmetic verifier.
LogicalResult spirv::GroupNonUniformSMinOp::verify() {
  return verifyGroupNonUniformArithmeticOp(*this);
}

/// Parses via the shared group non-uniform arithmetic parser.
ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
                                                OperationState &result) {
  return parseGroupNonUniformArithmeticOp(parser, result);
}
/// Prints via the shared group non-uniform arithmetic printer.
void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
  printGroupNonUniformArithmeticOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spv.GroupNonUniformUMaxOp
//===----------------------------------------------------------------------===// LogicalResult spirv::GroupNonUniformUMaxOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser, OperationState &result) { return parseGroupNonUniformArithmeticOp(parser, result); } void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) { printGroupNonUniformArithmeticOp(*this, p); } //===----------------------------------------------------------------------===// // spv.GroupNonUniformUMinOp //===----------------------------------------------------------------------===// LogicalResult spirv::GroupNonUniformUMinOp::verify() { return verifyGroupNonUniformArithmeticOp(*this); } ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser, OperationState &result) { return parseGroupNonUniformArithmeticOp(parser, result); } void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) { printGroupNonUniformArithmeticOp(*this, p); } //===----------------------------------------------------------------------===// // spv.ISubBorrowOp //===----------------------------------------------------------------------===// LogicalResult spirv::ISubBorrowOp::verify() { auto resultType = getType().cast(); if (resultType.getNumElements() != 2) return emitOpError("expected result struct type containing two members"); SmallVector types; types.push_back(operand1().getType()); types.push_back(operand2().getType()); types.push_back(resultType.getElementType(0)); types.push_back(resultType.getElementType(1)); if (!llvm::is_splat(types)) return emitOpError( "expected all operand types and struct member types are the same"); return success(); } ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser, OperationState &state) { SmallVector operands; if (parser.parseOptionalAttrDict(state.attributes) || parser.parseOperandList(operands) || parser.parseColon()) return failure(); Type resultType; auto loc = 
parser.getCurrentLocation(); if (parser.parseType(resultType)) return failure(); auto structType = resultType.dyn_cast(); if (!structType || structType.getNumElements() != 2) return parser.emitError(loc, "expected spv.struct type with two members"); SmallVector operandTypes(2, structType.getElementType(0)); if (parser.resolveOperands(operands, operandTypes, loc, state.operands)) return failure(); state.addTypes(resultType); return success(); } void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) { printer << ' '; printer.printOptionalAttrDict((*this)->getAttrs()); printer.printOperands((*this)->getOperands()); printer << " : " << getType(); } //===----------------------------------------------------------------------===// // spv.LoadOp //===----------------------------------------------------------------------===// void spirv::LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr, MemoryAccessAttr memoryAccess, IntegerAttr alignment) { auto ptrType = basePtr.getType().cast(); build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, alignment); } ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &state) { // Parse the storage class specification spirv::StorageClass storageClass; OpAsmParser::UnresolvedOperand ptrInfo; Type elementType; if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || parseMemoryAccessAttributes(parser, state) || parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() || parser.parseType(elementType)) { return failure(); } auto ptrType = spirv::PointerType::get(elementType, storageClass); if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) { return failure(); } state.addTypes(elementType); return success(); } void spirv::LoadOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( ptr().getType().cast().getStorageClass()); printer << " \"" << sc << "\" " << ptr(); printMemoryAccessAttribute(*this, printer, 
elidedAttrs); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); printer << " : " << getType(); } LogicalResult spirv::LoadOp::verify() { // SPIR-V spec : "Result Type is the type of the loaded object. It must be a // type with fixed size; i.e., it cannot be, nor include, any // OpTypeRuntimeArray types." if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) { return failure(); } return verifyMemoryAccessAttribute(*this); } //===----------------------------------------------------------------------===// // spv.mlir.loop //===----------------------------------------------------------------------===// void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) { state.addAttribute("loop_control", builder.getI32IntegerAttr( static_cast(spirv::LoopControl::None))); state.addRegion(); } ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &state) { if (parseControlAttribute(parser, state)) return failure(); return parser.parseRegion(*state.addRegion(), /*arguments=*/{}, /*argTypes=*/{}); } void spirv::LoopOp::print(OpAsmPrinter &printer) { auto control = loop_control(); if (control != spirv::LoopControl::None) printer << " control(" << spirv::stringifyLoopControl(control) << ")"; printer << ' '; printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } /// Returns true if the given `srcBlock` contains only one `spv.Branch` to the /// given `dstBlock`. static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { // Check that there is only one op in the `srcBlock`. 
if (!llvm::hasSingleElement(srcBlock)) return false; auto branchOp = dyn_cast(srcBlock.back()); return branchOp && branchOp.getSuccessor() == &dstBlock; } LogicalResult spirv::LoopOp::verifyRegions() { auto *op = getOperation(); // We need to verify that the blocks follow the following layout: // // +-------------+ // | entry block | // +-------------+ // | // v // +-------------+ // | loop header | <-----+ // +-------------+ | // | // ... | // \ | / | // v | // +---------------+ | // | loop continue | -----+ // +---------------+ // // ... // \ | / // v // +-------------+ // | merge block | // +-------------+ auto ®ion = op->getRegion(0); // Allow empty region as a degenerated case, which can come from // optimizations. if (region.empty()) return success(); // The last block is the merge block. Block &merge = region.back(); if (!isMergeBlock(merge)) return emitOpError( "last block must be the merge block with only one 'spv.mlir.merge' op"); if (std::next(region.begin()) == region.end()) return emitOpError( "must have an entry block branching to the loop header block"); // The first block is the entry block. Block &entry = region.front(); if (std::next(region.begin(), 2) == region.end()) return emitOpError( "must have a loop header block branched from the entry block"); // The second block is the loop header block. Block &header = *std::next(region.begin(), 1); if (!hasOneBranchOpTo(entry, header)) return emitOpError( "entry block must only have one 'spv.Branch' op to the second block"); if (std::next(region.begin(), 3) == region.end()) return emitOpError( "requires a loop continue block branching to the loop header block"); // The second to last block is the loop continue block. Block &cont = *std::prev(region.end(), 2); // Make sure that we have a branch from the loop continue block to the loop // header block. 
if (llvm::none_of( llvm::seq(0, cont.getNumSuccessors()), [&](unsigned index) { return cont.getSuccessor(index) == &header; })) return emitOpError("second to last block must be the loop continue " "block that branches to the loop header block"); // Make sure that no other blocks (except the entry and loop continue block) // branches to the loop header block. for (auto &block : llvm::make_range(std::next(region.begin(), 2), std::prev(region.end(), 2))) { for (auto i : llvm::seq(0, block.getNumSuccessors())) { if (block.getSuccessor(i) == &header) { return emitOpError("can only have the entry and loop continue " "block branching to the loop header block"); } } } return success(); } Block *spirv::LoopOp::getEntryBlock() { assert(!body().empty() && "op region should not be empty!"); return &body().front(); } Block *spirv::LoopOp::getHeaderBlock() { assert(!body().empty() && "op region should not be empty!"); // The second block is the loop header block. return &*std::next(body().begin()); } Block *spirv::LoopOp::getContinueBlock() { assert(!body().empty() && "op region should not be empty!"); // The second to last block is the loop continue block. return &*std::prev(body().end(), 2); } Block *spirv::LoopOp::getMergeBlock() { assert(!body().empty() && "op region should not be empty!"); // The last block is the loop merge block. return &body().back(); } void spirv::LoopOp::addEntryAndMergeBlock() { assert(body().empty() && "entry and merge block already exist"); body().push_back(new Block()); auto *mergeBlock = new Block(); body().push_back(mergeBlock); OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); // Add a spv.mlir.merge op into the merge block. 
builder.create(getLoc()); } //===----------------------------------------------------------------------===// // spv.MemoryBarrierOp //===----------------------------------------------------------------------===// LogicalResult spirv::MemoryBarrierOp::verify() { return verifyMemorySemantics(getOperation(), memory_semantics()); } //===----------------------------------------------------------------------===// // spv.mlir.merge //===----------------------------------------------------------------------===// LogicalResult spirv::MergeOp::verify() { auto *parentOp = (*this)->getParentOp(); if (!parentOp || !isa(parentOp)) return emitOpError( "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'"); // TODO: This check should be done in `verifyRegions` of parent op. Block &parentLastBlock = (*this)->getParentRegion()->back(); if (getOperation() != parentLastBlock.getTerminator()) return emitOpError("can only be used in the last block of " "'spv.mlir.selection' or 'spv.mlir.loop'"); return success(); } //===----------------------------------------------------------------------===// // spv.module //===----------------------------------------------------------------------===// void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, Optional name) { OpBuilder::InsertionGuard guard(builder); builder.createBlock(state.addRegion()); if (name) { state.attributes.append(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)); } } void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, spirv::AddressingModel addressingModel, spirv::MemoryModel memoryModel, Optional vceTriple, Optional name) { state.addAttribute( "addressing_model", builder.getI32IntegerAttr(static_cast(addressingModel))); state.addAttribute("memory_model", builder.getI32IntegerAttr( static_cast(memoryModel))); OpBuilder::InsertionGuard guard(builder); builder.createBlock(state.addRegion()); if (vceTriple) state.addAttribute(getVCETripleAttrName(), *vceTriple); 
if (name) state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(*name)); } ParseResult spirv::ModuleOp::parse(OpAsmParser &parser, OperationState &state) { Region *body = state.addRegion(); // If the name is present, parse it. StringAttr nameAttr; (void)parser.parseOptionalSymbolName( nameAttr, mlir::SymbolTable::getSymbolAttrName(), state.attributes); // Parse attributes spirv::AddressingModel addrModel; spirv::MemoryModel memoryModel; if (::parseEnumKeywordAttr(addrModel, parser, state) || ::parseEnumKeywordAttr(memoryModel, parser, state)) return failure(); if (succeeded(parser.parseOptionalKeyword("requires"))) { spirv::VerCapExtAttr vceTriple; if (parser.parseAttribute(vceTriple, spirv::ModuleOp::getVCETripleAttrName(), state.attributes)) return failure(); } if (parser.parseOptionalAttrDictWithKeyword(state.attributes) || parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{})) return failure(); // Make sure we have at least one block. if (body->empty()) body->push_back(new Block()); return success(); } void spirv::ModuleOp::print(OpAsmPrinter &printer) { if (Optional name = getName()) { printer << ' '; printer.printSymbolName(*name); } SmallVector elidedAttrs; printer << " " << spirv::stringifyAddressingModel(addressing_model()) << " " << spirv::stringifyMemoryModel(memory_model()); auto addressingModelAttrName = spirv::attributeName(); auto memoryModelAttrName = spirv::attributeName(); elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName, mlir::SymbolTable::getSymbolAttrName()}); if (Optional triple = vce_triple()) { printer << " requires " << *triple; elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName()); } printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs); printer << ' '; printer.printRegion(getRegion()); } LogicalResult spirv::ModuleOp::verifyRegions() { Dialect *dialect = (*this)->getDialect(); DenseMap, spirv::EntryPointOp> entryPoints; mlir::SymbolTable table(*this); for 
(auto &op : *getBody()) { if (op.getDialect() != dialect) return op.emitError("'spv.module' can only contain spv.* ops"); // For EntryPoint op, check that the function and execution model is not // duplicated in EntryPointOps. Also verify that the interface specified // comes from globalVariables here to make this check cheaper. if (auto entryPointOp = dyn_cast(op)) { auto funcOp = table.lookup(entryPointOp.fn()); if (!funcOp) { return entryPointOp.emitError("function '") << entryPointOp.fn() << "' not found in 'spv.module'"; } if (auto interface = entryPointOp.interface()) { for (Attribute varRef : interface) { auto varSymRef = varRef.dyn_cast(); if (!varSymRef) { return entryPointOp.emitError( "expected symbol reference for interface " "specification instead of '") << varRef; } auto variableOp = table.lookup(varSymRef.getValue()); if (!variableOp) { return entryPointOp.emitError("expected spv.GlobalVariable " "symbol reference instead of'") << varSymRef << "'"; } } } auto key = std::pair( funcOp, entryPointOp.execution_model()); auto entryPtIt = entryPoints.find(key); if (entryPtIt != entryPoints.end()) { return entryPointOp.emitError("duplicate of a previous EntryPointOp"); } entryPoints[key] = entryPointOp; } else if (auto funcOp = dyn_cast(op)) { if (funcOp.isExternal()) return op.emitError("'spv.module' cannot contain external functions"); // TODO: move this check to spv.func. 
for (auto &block : funcOp) for (auto &op : block) { if (op.getDialect() != dialect) return op.emitError( "functions in 'spv.module' can only contain spv.* ops"); } } } return success(); } //===----------------------------------------------------------------------===// // spv.mlir.referenceof //===----------------------------------------------------------------------===// LogicalResult spirv::ReferenceOfOp::verify() { auto *specConstSym = SymbolTable::lookupNearestSymbolFrom( (*this)->getParentOp(), spec_constAttr()); Type constType; auto specConstOp = dyn_cast_or_null(specConstSym); if (specConstOp) constType = specConstOp.default_value().getType(); auto specConstCompositeOp = dyn_cast_or_null(specConstSym); if (specConstCompositeOp) constType = specConstCompositeOp.type(); if (!specConstOp && !specConstCompositeOp) return emitOpError( "expected spv.SpecConstant or spv.SpecConstantComposite symbol"); if (reference().getType() != constType) return emitOpError("result type mismatch with the referenced " "specialization constant's type"); return success(); } //===----------------------------------------------------------------------===// // spv.Return //===----------------------------------------------------------------------===// LogicalResult spirv::ReturnOp::verify() { // Verification is performed in spv.func op. return success(); } //===----------------------------------------------------------------------===// // spv.ReturnValue //===----------------------------------------------------------------------===// LogicalResult spirv::ReturnValueOp::verify() { // Verification is performed in spv.func op. 
return success(); } //===----------------------------------------------------------------------===// // spv.Select //===----------------------------------------------------------------------===// LogicalResult spirv::SelectOp::verify() { if (auto conditionTy = condition().getType().dyn_cast()) { auto resultVectorTy = result().getType().dyn_cast(); if (!resultVectorTy) { return emitOpError("result expected to be of vector type when " "condition is of vector type"); } if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) { return emitOpError("result should have the same number of elements as " "the condition when condition is of vector type"); } } return success(); } //===----------------------------------------------------------------------===// // spv.mlir.selection //===----------------------------------------------------------------------===// ParseResult spirv::SelectionOp::parse(OpAsmParser &parser, OperationState &state) { if (parseControlAttribute(parser, state)) return failure(); return parser.parseRegion(*state.addRegion(), /*arguments=*/{}, /*argTypes=*/{}); } void spirv::SelectionOp::print(OpAsmPrinter &printer) { auto control = selection_control(); if (control != spirv::SelectionControl::None) printer << " control(" << spirv::stringifySelectionControl(control) << ")"; printer << ' '; printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); } LogicalResult spirv::SelectionOp::verifyRegions() { auto *op = getOperation(); // We need to verify that the blocks follow the following layout: // // +--------------+ // | header block | // +--------------+ // / | \ // ... // // // +---------+ +---------+ +---------+ // | case #0 | | case #1 | | case #2 | ... // +---------+ +---------+ +---------+ // // // ... // \ | / // v // +-------------+ // | merge block | // +-------------+ auto ®ion = op->getRegion(0); // Allow empty region as a degenerated case, which can come from // optimizations. 
if (region.empty()) return success(); // The last block is the merge block. if (!isMergeBlock(region.back())) return emitOpError( "last block must be the merge block with only one 'spv.mlir.merge' op"); if (std::next(region.begin()) == region.end()) return emitOpError("must have a selection header block"); return success(); } Block *spirv::SelectionOp::getHeaderBlock() { assert(!body().empty() && "op region should not be empty!"); // The first block is the loop header block. return &body().front(); } Block *spirv::SelectionOp::getMergeBlock() { assert(!body().empty() && "op region should not be empty!"); // The last block is the loop merge block. return &body().back(); } void spirv::SelectionOp::addMergeBlock() { assert(body().empty() && "entry and merge block already exist"); auto *mergeBlock = new Block(); body().push_back(mergeBlock); OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); // Add a spv.mlir.merge op into the merge block. builder.create(getLoc()); } spirv::SelectionOp spirv::SelectionOp::createIfThen( Location loc, Value condition, function_ref thenBody, OpBuilder &builder) { auto selectionOp = builder.create(loc, spirv::SelectionControl::None); selectionOp.addMergeBlock(); Block *mergeBlock = selectionOp.getMergeBlock(); Block *thenBlock = nullptr; // Build the "then" block. { OpBuilder::InsertionGuard guard(builder); thenBlock = builder.createBlock(mergeBlock); thenBody(builder); builder.create(loc, mergeBlock); } // Build the header block. 
{ OpBuilder::InsertionGuard guard(builder); builder.createBlock(thenBlock); builder.create( loc, condition, thenBlock, /*trueArguments=*/ArrayRef(), mergeBlock, /*falseArguments=*/ArrayRef()); } return selectionOp; } //===----------------------------------------------------------------------===// // spv.SpecConstant //===----------------------------------------------------------------------===// ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser, OperationState &state) { StringAttr nameAttr; Attribute valueAttr; if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), state.attributes)) return failure(); // Parse optional spec_id. if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) { IntegerAttr specIdAttr; if (parser.parseLParen() || parser.parseAttribute(specIdAttr, kSpecIdAttrName, state.attributes) || parser.parseRParen()) return failure(); } if (parser.parseEqual() || parser.parseAttribute(valueAttr, kDefaultValueAttrName, state.attributes)) return failure(); return success(); } void spirv::SpecConstantOp::print(OpAsmPrinter &printer) { printer << ' '; printer.printSymbolName(sym_name()); if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName)) printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; printer << " = " << default_value(); } LogicalResult spirv::SpecConstantOp::verify() { if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName)) if (specID.getValue().isNegative()) return emitOpError("SpecId cannot be negative"); auto value = default_value(); if (value.isa()) { // Make sure bitwidth is allowed. 
if (!value.getType().isa()) return emitOpError("default value bitwidth disallowed"); return success(); } return emitOpError( "default value can only be a bool, integer, or float scalar"); } //===----------------------------------------------------------------------===// // spv.StoreOp //===----------------------------------------------------------------------===// ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &state) { // Parse the storage class specification spirv::StorageClass storageClass; SmallVector operandInfo; auto loc = parser.getCurrentLocation(); Type elementType; if (parseEnumStrAttr(storageClass, parser) || parser.parseOperandList(operandInfo, 2) || parseMemoryAccessAttributes(parser, state) || parser.parseColon() || parser.parseType(elementType)) { return failure(); } auto ptrType = spirv::PointerType::get(elementType, storageClass); if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, state.operands)) { return failure(); } return success(); } void spirv::StoreOp::print(OpAsmPrinter &printer) { SmallVector elidedAttrs; StringRef sc = stringifyStorageClass( ptr().getType().cast().getStorageClass()); printer << " \"" << sc << "\" " << ptr() << ", " << value(); printMemoryAccessAttribute(*this, printer, elidedAttrs); printer << " : " << value().getType(); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); } LogicalResult spirv::StoreOp::verify() { // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an // OpTypePointer whose Type operand is the same as the type of Object." 
if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) return failure(); return verifyMemoryAccessAttribute(*this); } //===----------------------------------------------------------------------===// // spv.Unreachable //===----------------------------------------------------------------------===// LogicalResult spirv::UnreachableOp::verify() { auto *block = (*this)->getBlock(); // Fast track: if this is in entry block, its invalid. Otherwise, if no // predecessors, it's valid. if (block->isEntryBlock()) return emitOpError("cannot be used in reachable block"); if (block->hasNoPredecessors()) return success(); // TODO: further verification needs to analyze reachability from // the entry block. return success(); } //===----------------------------------------------------------------------===// // spv.Variable //===----------------------------------------------------------------------===// ParseResult spirv::VariableOp::parse(OpAsmParser &parser, OperationState &state) { // Parse optional initializer Optional initInfo; if (succeeded(parser.parseOptionalKeyword("init"))) { initInfo = OpAsmParser::UnresolvedOperand(); if (parser.parseLParen() || parser.parseOperand(*initInfo) || parser.parseRParen()) return failure(); } if (parseVariableDecorations(parser, state)) { return failure(); } // Parse result pointer type Type type; if (parser.parseColon()) return failure(); auto loc = parser.getCurrentLocation(); if (parser.parseType(type)) return failure(); auto ptrType = type.dyn_cast(); if (!ptrType) return parser.emitError(loc, "expected spv.ptr type"); state.addTypes(ptrType); // Resolve the initializer operand if (initInfo) { if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), state.operands)) return failure(); } auto attr = parser.getBuilder().getI32IntegerAttr( llvm::bit_cast(ptrType.getStorageClass())); state.addAttribute(spirv::attributeName(), attr); return success(); } void spirv::VariableOp::print(OpAsmPrinter &printer) { SmallVector 
elidedAttrs{ spirv::attributeName()}; // Print optional initializer if (getNumOperands() != 0) printer << " init(" << initializer() << ")"; printVariableDecorations(*this, printer, elidedAttrs); printer << " : " << getType(); } LogicalResult spirv::VariableOp::verify() { // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the // object. It cannot be Generic. It must be the same as the Storage Class // operand of the Result Type." if (storage_class() != spirv::StorageClass::Function) { return emitOpError( "can only be used to model function-level variables. Use " "spv.GlobalVariable for module-level variables."); } auto pointerType = pointer().getType().cast(); if (storage_class() != pointerType.getStorageClass()) return emitOpError( "storage class must match result pointer's storage class"); if (getNumOperands() != 0) { // SPIR-V spec: "Initializer must be an from a constant instruction or // a global (module scope) OpVariable instruction". auto *initOp = getOperand(0).getDefiningOp(); if (!initOp || !isa(initOp)) return emitOpError("initializer must be the result of a " "constant or spv.GlobalVariable op"); } // TODO: generate these strings using ODS. 
auto *op = getOperation(); auto descriptorSetName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::DescriptorSet)); auto bindingName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::Binding)); auto builtInName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::BuiltIn)); for (const auto &attr : {descriptorSetName, bindingName, builtInName}) { if (op->getAttr(attr)) return emitOpError("cannot have '") << attr << "' attribute (only allowed in spv.GlobalVariable)"; } return success(); } //===----------------------------------------------------------------------===// // spv.VectorShuffle //===----------------------------------------------------------------------===// LogicalResult spirv::VectorShuffleOp::verify() { VectorType resultType = getType().cast(); size_t numResultElements = resultType.getNumElements(); if (numResultElements != components().size()) return emitOpError("result type element count (") << numResultElements << ") mismatch with the number of component selectors (" << components().size() << ")"; size_t totalSrcElements = vector1().getType().cast().getNumElements() + vector2().getType().cast().getNumElements(); for (const auto &selector : components().getAsValueRange()) { uint32_t index = selector.getZExtValue(); if (index >= totalSrcElements && index != std::numeric_limits().max()) return emitOpError("component selector ") << index << " out of range: expected to be in [0, " << totalSrcElements << ") or 0xffffffff"; } return success(); } //===----------------------------------------------------------------------===// // spv.CooperativeMatrixLoadNV //===----------------------------------------------------------------------===// ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser, OperationState &state) { SmallVector operandInfo; Type strideType = parser.getBuilder().getIntegerType(32); Type columnMajorType = parser.getBuilder().getIntegerType(1); Type 
ptrType; Type elementType; if (parser.parseOperandList(operandInfo, 3) || parseMemoryAccessAttributes(parser, state) || parser.parseColon() || parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) { return failure(); } if (parser.resolveOperands(operandInfo, {ptrType, strideType, columnMajorType}, parser.getNameLoc(), state.operands)) { return failure(); } state.addTypes(elementType); return success(); } void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) { printer << " " << pointer() << ", " << stride() << ", " << columnmajor(); // Print optional memory access attribute. if (auto memAccess = memory_access()) printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; printer << " : " << pointer().getType() << " as " << getType(); } static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, Type coopMatrix) { Type pointeeType = pointer.cast().getPointeeType(); if (!pointeeType.isa() && !pointeeType.isa()) return op->emitError( "Pointer must point to a scalar or vector type but provided ") << pointeeType; spirv::StorageClass storage = pointer.cast().getStorageClass(); if (storage != spirv::StorageClass::Workgroup && storage != spirv::StorageClass::StorageBuffer && storage != spirv::StorageClass::PhysicalStorageBuffer) return op->emitError( "Pointer storage class must be Workgroup, StorageBuffer or " "PhysicalStorageBufferEXT but provided ") << stringifyStorageClass(storage); return success(); } LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() { return verifyPointerAndCoopMatrixType(*this, pointer().getType(), result().getType()); } //===----------------------------------------------------------------------===// // spv.CooperativeMatrixStoreNV //===----------------------------------------------------------------------===// ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser, OperationState &state) { SmallVector operandInfo; Type strideType = 
parser.getBuilder().getIntegerType(32); Type columnMajorType = parser.getBuilder().getIntegerType(1); Type ptrType; Type elementType; if (parser.parseOperandList(operandInfo, 4) || parseMemoryAccessAttributes(parser, state) || parser.parseColon() || parser.parseType(ptrType) || parser.parseComma() || parser.parseType(elementType)) { return failure(); } if (parser.resolveOperands( operandInfo, {ptrType, elementType, strideType, columnMajorType}, parser.getNameLoc(), state.operands)) { return failure(); } return success(); } void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) { printer << " " << pointer() << ", " << object() << ", " << stride() << ", " << columnmajor(); // Print optional memory access attribute. if (auto memAccess = memory_access()) printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; printer << " : " << pointer().getType() << ", " << getOperand(1).getType(); } LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() { return verifyPointerAndCoopMatrixType(*this, pointer().getType(), object().getType()); } //===----------------------------------------------------------------------===// // spv.CooperativeMatrixMulAddNV //===----------------------------------------------------------------------===// static LogicalResult verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) { if (op.c().getType() != op.result().getType()) return op.emitOpError("result and third operand must have the same type"); auto typeA = op.a().getType().cast(); auto typeB = op.b().getType().cast(); auto typeC = op.c().getType().cast(); auto typeR = op.result().getType().cast(); if (typeA.getRows() != typeR.getRows() || typeA.getColumns() != typeB.getRows() || typeB.getColumns() != typeR.getColumns()) return op.emitOpError("matrix size must match"); if (typeR.getScope() != typeA.getScope() || typeR.getScope() != typeB.getScope() || typeR.getScope() != typeC.getScope()) return op.emitOpError("matrix scope must match"); if (typeA.getElementType() 
!= typeB.getElementType() || typeR.getElementType() != typeC.getElementType()) return op.emitOpError("matrix element type must match"); return success(); } LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() { return verifyCoopMatrixMulAdd(*this); } //===----------------------------------------------------------------------===// // spv.MatrixTimesScalar //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesScalarOp::verify() { // We already checked that result and matrix are both of matrix type in the // auto-generated verify method. auto inputMatrix = matrix().getType().cast(); auto resultMatrix = result().getType().cast(); // Check that the scalar type is the same as the matrix element type. if (scalar().getType() != inputMatrix.getElementType()) return emitError("input matrix components' type and scaling value must " "have the same type"); // Note that the next three checks could be done using the AllTypesMatch // trait in the Op definition file but it generates a vague error message. 
// Check that the input and result matrices have the same columns' count if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns()) return emitError("input and result matrices must have the same " "number of columns"); // Check that the input and result matrices' have the same rows count if (inputMatrix.getNumRows() != resultMatrix.getNumRows()) return emitError("input and result matrices' columns must have " "the same size"); // Check that the input and result matrices' have the same component type if (inputMatrix.getElementType() != resultMatrix.getElementType()) return emitError("input and result matrices' columns must have " "the same component type"); return success(); } //===----------------------------------------------------------------------===// // spv.CopyMemory //===----------------------------------------------------------------------===// void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) { printer << ' '; StringRef targetStorageClass = stringifyStorageClass( target().getType().cast().getStorageClass()); printer << " \"" << targetStorageClass << "\" " << target() << ", "; StringRef sourceStorageClass = stringifyStorageClass( source().getType().cast().getStorageClass()); printer << " \"" << sourceStorageClass << "\" " << source(); SmallVector elidedAttrs; printMemoryAccessAttribute(*this, printer, elidedAttrs); printSourceMemoryAccessAttribute(*this, printer, elidedAttrs, source_memory_access(), source_alignment()); printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); Type pointeeType = target().getType().cast().getPointeeType(); printer << " : " << pointeeType; } ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser, OperationState &state) { spirv::StorageClass targetStorageClass; OpAsmParser::UnresolvedOperand targetPtrInfo; spirv::StorageClass sourceStorageClass; OpAsmParser::UnresolvedOperand sourcePtrInfo; Type elementType; if (parseEnumStrAttr(targetStorageClass, parser) || parser.parseOperand(targetPtrInfo) || 
parser.parseComma() || parseEnumStrAttr(sourceStorageClass, parser) || parser.parseOperand(sourcePtrInfo) || parseMemoryAccessAttributes(parser, state)) { return failure(); } if (!parser.parseOptionalComma()) { // Parse 2nd memory access attributes. if (parseSourceMemoryAccessAttributes(parser, state)) { return failure(); } } if (parser.parseColon() || parser.parseType(elementType)) return failure(); if (parser.parseOptionalAttrDict(state.attributes)) return failure(); auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass); auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass); if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) || parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) { return failure(); } return success(); } LogicalResult spirv::CopyMemoryOp::verify() { Type targetType = target().getType().cast().getPointeeType(); Type sourceType = source().getType().cast().getPointeeType(); if (targetType != sourceType) return emitOpError("both operands must be pointers to the same type"); if (failed(verifyMemoryAccessAttribute(*this))) return failure(); // TODO - According to the spec: // // If two masks are present, the first applies to Target and cannot include // MakePointerVisible, and the second applies to Source and cannot include // MakePointerAvailable. // // Add such verification here. return verifySourceMemoryAccessAttribute(*this); } //===----------------------------------------------------------------------===// // spv.Transpose //===----------------------------------------------------------------------===// LogicalResult spirv::TransposeOp::verify() { auto inputMatrix = matrix().getType().cast(); auto resultMatrix = result().getType().cast(); // Verify that the input and output matrices have correct shapes. 
if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) return emitError("input matrix rows count must be equal to " "output matrix columns count"); if (inputMatrix.getNumColumns() != resultMatrix.getNumRows()) return emitError("input matrix columns count must be equal to " "output matrix rows count"); // Verify that the input and output matrices have the same component type if (inputMatrix.getElementType() != resultMatrix.getElementType()) return emitError("input and output matrices must have the same " "component type"); return success(); } //===----------------------------------------------------------------------===// // spv.MatrixTimesMatrix //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesMatrixOp::verify() { auto leftMatrix = leftmatrix().getType().cast(); auto rightMatrix = rightmatrix().getType().cast(); auto resultMatrix = result().getType().cast(); // left matrix columns' count and right matrix rows' count must be equal if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) return emitError("left matrix columns' count must be equal to " "the right matrix rows' count"); // right and result matrices columns' count must be the same if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns()) return emitError( "right and result matrices must have equal columns' count"); // right and result matrices component type must be the same if (rightMatrix.getElementType() != resultMatrix.getElementType()) return emitError("right and result matrices' component type must" " be the same"); // left and result matrices component type must be the same if (leftMatrix.getElementType() != resultMatrix.getElementType()) return emitError("left and result matrices' component type" " must be the same"); // left and result matrices rows count must be the same if (leftMatrix.getNumRows() != resultMatrix.getNumRows()) return emitError("left and result matrices must have equal rows' count"); return 
success(); } //===----------------------------------------------------------------------===// // spv.SpecConstantComposite //===----------------------------------------------------------------------===// ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser, OperationState &state) { StringAttr compositeName; if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(), state.attributes)) return failure(); if (parser.parseLParen()) return failure(); SmallVector constituents; do { // The name of the constituent attribute isn't important const char *attrName = "spec_const"; FlatSymbolRefAttr specConstRef; NamedAttrList attrs; if (parser.parseAttribute(specConstRef, Type(), attrName, attrs)) return failure(); constituents.push_back(specConstRef); } while (!parser.parseOptionalComma()); if (parser.parseRParen()) return failure(); state.addAttribute(kCompositeSpecConstituentsName, parser.getBuilder().getArrayAttr(constituents)); Type type; if (parser.parseColonType(type)) return failure(); state.addAttribute(kTypeAttrName, TypeAttr::get(type)); return success(); } void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { printer << " "; printer.printSymbolName(sym_name()); printer << " ("; auto constituents = this->constituents().getValue(); if (!constituents.empty()) llvm::interleaveComma(constituents, printer); printer << ") : " << type(); } LogicalResult spirv::SpecConstantCompositeOp::verify() { auto cType = type().dyn_cast(); auto constituents = this->constituents().getValue(); if (!cType) return emitError("result type must be a composite type, but provided ") << type(); if (cType.isa()) return emitError("unsupported composite type ") << cType; if (constituents.size() != cType.getNumElements()) return emitError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " << constituents.size(); for (auto index : llvm::seq(0, constituents.size())) { auto constituent = constituents[index].cast(); 
auto constituentSpecConstOp = dyn_cast(SymbolTable::lookupNearestSymbolFrom( (*this)->getParentOp(), constituent.getAttr())); if (constituentSpecConstOp.default_value().getType() != cType.getElementType(index)) return emitError("has incorrect types of operands: expected ") << cType.getElementType(index) << ", but provided " << constituentSpecConstOp.default_value().getType(); } return success(); } //===----------------------------------------------------------------------===// // spv.SpecConstantOperation //===----------------------------------------------------------------------===// ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser, OperationState &state) { Region *body = state.addRegion(); if (parser.parseKeyword("wraps")) return failure(); body->push_back(new Block); Block &block = body->back(); Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); if (!wrappedOp) return failure(); OpBuilder builder(parser.getContext()); builder.setInsertionPointToEnd(&block); builder.create(wrappedOp->getLoc(), wrappedOp->getResult(0)); state.location = wrappedOp->getLoc(); state.addTypes(wrappedOp->getResult(0).getType()); if (parser.parseOptionalAttrDict(state.attributes)) return failure(); return success(); } void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) { printer << " wraps "; printer.printGenericOp(&body().front().front()); } LogicalResult spirv::SpecConstantOperationOp::verifyRegions() { Block &block = getRegion().getBlocks().front(); if (block.getOperations().size() != 2) return emitOpError("expected exactly 2 nested ops"); Operation &enclosedOp = block.getOperations().front(); if (!enclosedOp.hasTrait()) return emitOpError("invalid enclosed op"); for (auto operand : enclosedOp.getOperands()) if (!isa(operand.getDefiningOp())) return emitOpError( "invalid operand, must be defined by a constant operation"); return success(); } //===----------------------------------------------------------------------===// // 
spv.GL.FrexpStruct //===----------------------------------------------------------------------===// LogicalResult spirv::GLFrexpStructOp::verify() { spirv::StructType structTy = result().getType().dyn_cast(); if (structTy.getNumElements() != 2) return emitError("result type must be a struct type with two memebers"); Type significandTy = structTy.getElementType(0); Type exponentTy = structTy.getElementType(1); VectorType exponentVecTy = exponentTy.dyn_cast(); IntegerType exponentIntTy = exponentTy.dyn_cast(); Type operandTy = operand().getType(); VectorType operandVecTy = operandTy.dyn_cast(); FloatType operandFTy = operandTy.dyn_cast(); if (significandTy != operandTy) return emitError("member zero of the resulting struct type must be the " "same type as the operand"); if (exponentVecTy) { IntegerType componentIntTy = exponentVecTy.getElementType().dyn_cast(); if (!componentIntTy || componentIntTy.getWidth() != 32) return emitError("member one of the resulting struct type must" "be a scalar or vector of 32 bit integer type"); } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) { return emitError("member one of the resulting struct type " "must be a scalar or vector of 32 bit integer type"); } // Check that the two member types have the same number of components if (operandVecTy && exponentVecTy && (exponentVecTy.getNumElements() == operandVecTy.getNumElements())) return success(); if (operandFTy && exponentIntTy) return success(); return emitError("member one of the resulting struct type must have the same " "number of components as the operand type"); } //===----------------------------------------------------------------------===// // spv.GL.Ldexp //===----------------------------------------------------------------------===// LogicalResult spirv::GLLdexpOp::verify() { Type significandType = x().getType(); Type exponentType = exp().getType(); if (significandType.isa() != exponentType.isa()) return emitOpError("operands must both be scalars or vectors"); 
auto getNumElements = [](Type type) -> unsigned { if (auto vectorType = type.dyn_cast()) return vectorType.getNumElements(); return 1; }; if (getNumElements(significandType) != getNumElements(exponentType)) return emitOpError("operands must have the same number of elements"); return success(); } //===----------------------------------------------------------------------===// // spv.ImageDrefGather //===----------------------------------------------------------------------===// LogicalResult spirv::ImageDrefGatherOp::verify() { VectorType resultType = result().getType().cast(); auto sampledImageType = sampledimage().getType().cast(); auto imageType = sampledImageType.getImageType().cast(); if (resultType.getNumElements() != 4) return emitOpError("result type must be a vector of four components"); Type elementType = resultType.getElementType(); Type sampledElementType = imageType.getElementType(); if (!sampledElementType.isa() && elementType != sampledElementType) return emitOpError( "the component type of result must be the same as sampled type of the " "underlying image type"); spirv::Dim imageDim = imageType.getDim(); spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo(); if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube && imageDim != spirv::Dim::Rect) return emitOpError( "the Dim operand of the underlying image type must be 2D, Cube, or " "Rect"); if (imageMS != spirv::ImageSamplingInfo::SingleSampled) return emitOpError("the MS operand of the underlying image type must be 0"); spirv::ImageOperandsAttr attr = imageoperandsAttr(); auto operandArguments = operand_arguments(); return verifyImageOperands(*this, attr, operandArguments); } //===----------------------------------------------------------------------===// // spv.ShiftLeftLogicalOp //===----------------------------------------------------------------------===// LogicalResult spirv::ShiftLeftLogicalOp::verify() { return verifyShiftOp(*this); } 
//===----------------------------------------------------------------------===// // spv.ShiftRightArithmeticOp //===----------------------------------------------------------------------===// LogicalResult spirv::ShiftRightArithmeticOp::verify() { return verifyShiftOp(*this); } //===----------------------------------------------------------------------===// // spv.ShiftRightLogicalOp //===----------------------------------------------------------------------===// LogicalResult spirv::ShiftRightLogicalOp::verify() { return verifyShiftOp(*this); } //===----------------------------------------------------------------------===// // spv.ImageQuerySize //===----------------------------------------------------------------------===// LogicalResult spirv::ImageQuerySizeOp::verify() { spirv::ImageType imageType = image().getType().cast(); Type resultType = result().getType(); spirv::Dim dim = imageType.getDim(); spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo(); spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo(); switch (dim) { case spirv::Dim::Dim1D: case spirv::Dim::Dim2D: case spirv::Dim::Dim3D: case spirv::Dim::Cube: if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled && samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown && samplerInfo != spirv::ImageSamplerUseInfo::NoSampler) return emitError( "if Dim is 1D, 2D, 3D, or Cube, " "it must also have either an MS of 1 or a Sampled of 0 or 2"); break; case spirv::Dim::Buffer: case spirv::Dim::Rect: break; default: return emitError("the Dim operand of the image type must " "be 1D, 2D, 3D, Buffer, Cube, or Rect"); } unsigned componentNumber = 0; switch (dim) { case spirv::Dim::Dim1D: case spirv::Dim::Buffer: componentNumber = 1; break; case spirv::Dim::Dim2D: case spirv::Dim::Cube: case spirv::Dim::Rect: componentNumber = 2; break; case spirv::Dim::Dim3D: componentNumber = 3; break; default: break; } if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed) 
componentNumber += 1; unsigned resultComponentNumber = 1; if (auto resultVectorType = resultType.dyn_cast()) resultComponentNumber = resultVectorType.getNumElements(); if (componentNumber != resultComponentNumber) return emitError("expected the result to have ") << componentNumber << " component(s), but found " << resultComponentNumber << " component(s)"; return success(); } static ParseResult parsePtrAccessChainOpImpl(StringRef opName, OpAsmParser &parser, OperationState &state) { OpAsmParser::UnresolvedOperand ptrInfo; SmallVector indicesInfo; Type type; auto loc = parser.getCurrentLocation(); SmallVector indicesTypes; if (parser.parseOperand(ptrInfo) || parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || parser.parseColonType(type) || parser.resolveOperand(ptrInfo, type, state.operands)) return failure(); // Check that the provided indices list is not empty before parsing their // type list. if (indicesInfo.empty()) return emitError(state.location) << opName << " expected element"; if (parser.parseComma() || parser.parseTypeList(indicesTypes)) return failure(); // Check that the indices types list is not empty and that it has a one-to-one // mapping to the provided indices. 
if (indicesTypes.size() != indicesInfo.size()) return emitError(state.location) << opName << " indices types' count must be equal to indices info count"; if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) return failure(); auto resultType = getElementPtrType( type, llvm::makeArrayRef(state.operands).drop_front(2), state.location); if (!resultType) return failure(); state.addTypes(resultType); return success(); } template static auto concatElemAndIndices(Op op) { SmallVector ret(op.indices().size() + 1); ret[0] = op.element(); llvm::copy(op.indices(), ret.begin() + 1); return ret; } //===----------------------------------------------------------------------===// // spv.InBoundsPtrAccessChainOp //===----------------------------------------------------------------------===// void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state, Value basePtr, Value element, ValueRange indices) { auto type = getElementPtrType(basePtr.getType(), indices, state.location); assert(type && "Unable to deduce return type based on basePtr and indices"); build(builder, state, type, basePtr, element, indices); } ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser, OperationState &state) { return parsePtrAccessChainOpImpl( spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state); } void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) { printAccessChain(*this, concatElemAndIndices(*this), printer); } LogicalResult spirv::InBoundsPtrAccessChainOp::verify() { return verifyAccessChain(*this, indices()); } //===----------------------------------------------------------------------===// // spv.PtrAccessChainOp //===----------------------------------------------------------------------===// void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state, Value basePtr, Value element, ValueRange indices) { auto type = getElementPtrType(basePtr.getType(), indices, state.location); 
assert(type && "Unable to deduce return type based on basePtr and indices"); build(builder, state, type, basePtr, element, indices); } ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser, OperationState &state) { return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(), parser, state); } void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) { printAccessChain(*this, concatElemAndIndices(*this), printer); } LogicalResult spirv::PtrAccessChainOp::verify() { return verifyAccessChain(*this, indices()); } //===----------------------------------------------------------------------===// // spv.VectorTimesScalarOp //===----------------------------------------------------------------------===// LogicalResult spirv::VectorTimesScalarOp::verify() { if (vector().getType() != getType()) return emitOpError("vector operand and result type mismatch"); auto scalarType = getType().cast().getElementType(); if (scalar().getType() != scalarType) return emitOpError("scalar operand and result element type match"); return success(); } // TableGen'erated operation interfaces for querying versions, extensions, and // capabilities. #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc" // TablenGen'erated operation definitions. #define GET_OP_CLASSES #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc" namespace mlir { namespace spirv { // TableGen'erated operation availability interface implementations. #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc" } // namespace spirv } // namespace mlir