1 //===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
17 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser/Parser.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/raw_ostream.h"
32
33 using namespace mlir;
34 using namespace mlir::spirv;
35
36 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
37
38 //===----------------------------------------------------------------------===//
39 // InlinerInterface
40 //===----------------------------------------------------------------------===//
41
42 /// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
containsReturn(Region & region)43 static inline bool containsReturn(Region ®ion) {
44 return llvm::any_of(region, [](Block &block) {
45 Operation *terminator = block.getTerminator();
46 return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
47 });
48 }
49
50 namespace {
51 /// This class defines the interface for inlining within the SPIR-V dialect.
52 struct SPIRVInlinerInterface : public DialectInlinerInterface {
53 using DialectInlinerInterface::DialectInlinerInterface;
54
55 /// All call operations within SPIRV can be inlined.
isLegalToInline__anon52aeee350211::SPIRVInlinerInterface56 bool isLegalToInline(Operation *call, Operation *callable,
57 bool wouldBeCloned) const final {
58 return true;
59 }
60
61 /// Returns true if the given region 'src' can be inlined into the region
62 /// 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline__anon52aeee350211::SPIRVInlinerInterface63 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64 BlockAndValueMapping &) const final {
65 // Return true here when inlining into spv.func, spv.mlir.selection, and
66 // spv.mlir.loop operations.
67 auto *op = dest->getParentOp();
68 return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69 }
70
71 /// Returns true if the given operation 'op', that is registered to this
72 /// dialect, can be inlined into the region 'dest' that is attached to an
73 /// operation registered to the current dialect.
isLegalToInline__anon52aeee350211::SPIRVInlinerInterface74 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75 BlockAndValueMapping &) const final {
76 // TODO: Enable inlining structured control flows with return.
77 if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
78 containsReturn(op->getRegion(0)))
79 return false;
80 // TODO: we need to filter OpKill here to avoid inlining it to
81 // a loop continue construct:
82 // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83 // However OpKill is fragment shader specific and we don't support it yet.
84 return true;
85 }
86
87 /// Handle the given inlined terminator by replacing it with a new operation
88 /// as necessary.
handleTerminator__anon52aeee350211::SPIRVInlinerInterface89 void handleTerminator(Operation *op, Block *newDest) const final {
90 if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
91 OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
92 op->erase();
93 } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
94 llvm_unreachable("unimplemented spv.ReturnValue in inliner");
95 }
96 }
97
98 /// Handle the given inlined terminator by replacing it with a new operation
99 /// as necessary.
handleTerminator__anon52aeee350211::SPIRVInlinerInterface100 void handleTerminator(Operation *op,
101 ArrayRef<Value> valuesToRepl) const final {
102 // Only spv.ReturnValue needs to be handled here.
103 auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
104 if (!retValOp)
105 return;
106
107 // Replace the values directly with the return operands.
108 assert(valuesToRepl.size() == 1 &&
109 "spv.ReturnValue expected to only handle one result");
110 valuesToRepl.front().replaceAllUsesWith(retValOp.value());
111 }
112 };
113 } // namespace
114
115 //===----------------------------------------------------------------------===//
116 // SPIR-V Dialect
117 //===----------------------------------------------------------------------===//
118
initialize()119 void SPIRVDialect::initialize() {
120 registerAttributes();
121 registerTypes();
122
123 // Add SPIR-V ops.
124 addOperations<
125 #define GET_OP_LIST
126 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
127 >();
128
129 addInterfaces<SPIRVInlinerInterface>();
130
131 // Allow unknown operations because SPIR-V is extensible.
132 allowUnknownOperations();
133 }
134
getAttributeName(Decoration decoration)135 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
136 return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
137 }
138
139 //===----------------------------------------------------------------------===//
140 // Type Parsing
141 //===----------------------------------------------------------------------===//
142
143 // Forward declarations.
144 template <typename ValTy>
145 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
146 DialectAsmParser &parser);
147 template <>
148 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
149 DialectAsmParser &parser);
150
151 template <>
152 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
153 DialectAsmParser &parser);
154
parseAndVerifyType(SPIRVDialect const & dialect,DialectAsmParser & parser)155 static Type parseAndVerifyType(SPIRVDialect const &dialect,
156 DialectAsmParser &parser) {
157 Type type;
158 SMLoc typeLoc = parser.getCurrentLocation();
159 if (parser.parseType(type))
160 return Type();
161
162 // Allow SPIR-V dialect types
163 if (&type.getDialect() == &dialect)
164 return type;
165
166 // Check other allowed types
167 if (auto t = type.dyn_cast<FloatType>()) {
168 if (type.isBF16()) {
169 parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
170 return Type();
171 }
172 } else if (auto t = type.dyn_cast<IntegerType>()) {
173 if (!ScalarType::isValid(t)) {
174 parser.emitError(typeLoc,
175 "only 1/8/16/32/64-bit integer type allowed but found ")
176 << type;
177 return Type();
178 }
179 } else if (auto t = type.dyn_cast<VectorType>()) {
180 if (t.getRank() != 1) {
181 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
182 return Type();
183 }
184 if (t.getNumElements() > 4) {
185 parser.emitError(
186 typeLoc, "vector length has to be less than or equal to 4 but found ")
187 << t.getNumElements();
188 return Type();
189 }
190 } else {
191 parser.emitError(typeLoc, "cannot use ")
192 << type << " to compose SPIR-V types";
193 return Type();
194 }
195
196 return type;
197 }
198
parseAndVerifyMatrixType(SPIRVDialect const & dialect,DialectAsmParser & parser)199 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
200 DialectAsmParser &parser) {
201 Type type;
202 SMLoc typeLoc = parser.getCurrentLocation();
203 if (parser.parseType(type))
204 return Type();
205
206 if (auto t = type.dyn_cast<VectorType>()) {
207 if (t.getRank() != 1) {
208 parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
209 return Type();
210 }
211 if (t.getNumElements() > 4 || t.getNumElements() < 2) {
212 parser.emitError(typeLoc,
213 "matrix columns size has to be less than or equal "
214 "to 4 and greater than or equal 2, but found ")
215 << t.getNumElements();
216 return Type();
217 }
218
219 if (!t.getElementType().isa<FloatType>()) {
220 parser.emitError(typeLoc, "matrix columns' elements must be of "
221 "Float type, got ")
222 << t.getElementType();
223 return Type();
224 }
225 } else {
226 parser.emitError(typeLoc, "matrix must be composed using vector "
227 "type, got ")
228 << type;
229 return Type();
230 }
231
232 return type;
233 }
234
parseAndVerifySampledImageType(SPIRVDialect const & dialect,DialectAsmParser & parser)235 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
236 DialectAsmParser &parser) {
237 Type type;
238 SMLoc typeLoc = parser.getCurrentLocation();
239 if (parser.parseType(type))
240 return Type();
241
242 if (!type.isa<ImageType>()) {
243 parser.emitError(typeLoc,
244 "sampled image must be composed using image type, got ")
245 << type;
246 return Type();
247 }
248
249 return type;
250 }
251
252 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
253 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
254 /// missing.
parseOptionalArrayStride(const SPIRVDialect & dialect,DialectAsmParser & parser,unsigned & stride)255 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
256 DialectAsmParser &parser,
257 unsigned &stride) {
258 if (failed(parser.parseOptionalComma())) {
259 stride = 0;
260 return success();
261 }
262
263 if (parser.parseKeyword("stride") || parser.parseEqual())
264 return failure();
265
266 SMLoc strideLoc = parser.getCurrentLocation();
267 Optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
268 if (!optStride)
269 return failure();
270
271 if (!(stride = *optStride)) {
272 parser.emitError(strideLoc, "ArrayStride must be greater than zero");
273 return failure();
274 }
275 return success();
276 }
277
278 // element-type ::= integer-type
279 // | floating-point-type
280 // | vector-type
281 // | spirv-type
282 //
283 // array-type ::= `!spv.array` `<` integer-literal `x` element-type
284 // (`,` `stride` `=` integer-literal)? `>`
parseArrayType(SPIRVDialect const & dialect,DialectAsmParser & parser)285 static Type parseArrayType(SPIRVDialect const &dialect,
286 DialectAsmParser &parser) {
287 if (parser.parseLess())
288 return Type();
289
290 SmallVector<int64_t, 1> countDims;
291 SMLoc countLoc = parser.getCurrentLocation();
292 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
293 return Type();
294 if (countDims.size() != 1) {
295 parser.emitError(countLoc,
296 "expected single integer for array element count");
297 return Type();
298 }
299
300 // According to the SPIR-V spec:
301 // "Length is the number of elements in the array. It must be at least 1."
302 int64_t count = countDims[0];
303 if (count == 0) {
304 parser.emitError(countLoc, "expected array length greater than 0");
305 return Type();
306 }
307
308 Type elementType = parseAndVerifyType(dialect, parser);
309 if (!elementType)
310 return Type();
311
312 unsigned stride = 0;
313 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
314 return Type();
315
316 if (parser.parseGreater())
317 return Type();
318 return ArrayType::get(elementType, count, stride);
319 }
320
321 // cooperative-matrix-type ::= `!spv.coopmatrix` `<` element-type ',' scope ','
322 // rows ',' columns>`
parseCooperativeMatrixType(SPIRVDialect const & dialect,DialectAsmParser & parser)323 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
324 DialectAsmParser &parser) {
325 if (parser.parseLess())
326 return Type();
327
328 SmallVector<int64_t, 2> dims;
329 SMLoc countLoc = parser.getCurrentLocation();
330 if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
331 return Type();
332
333 if (dims.size() != 2) {
334 parser.emitError(countLoc, "expected rows and columns size");
335 return Type();
336 }
337
338 auto elementTy = parseAndVerifyType(dialect, parser);
339 if (!elementTy)
340 return Type();
341
342 Scope scope;
343 if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
344 return Type();
345
346 if (parser.parseGreater())
347 return Type();
348 return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
349 }
350
351 // TODO: Reorder methods to be utilities first and parse*Type
352 // methods in alphabetical order
353 //
354 // storage-class ::= `UniformConstant`
355 // | `Uniform`
356 // | `Workgroup`
357 // | <and other storage classes...>
358 //
359 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
parsePointerType(SPIRVDialect const & dialect,DialectAsmParser & parser)360 static Type parsePointerType(SPIRVDialect const &dialect,
361 DialectAsmParser &parser) {
362 if (parser.parseLess())
363 return Type();
364
365 auto pointeeType = parseAndVerifyType(dialect, parser);
366 if (!pointeeType)
367 return Type();
368
369 StringRef storageClassSpec;
370 SMLoc storageClassLoc = parser.getCurrentLocation();
371 if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
372 return Type();
373
374 auto storageClass = symbolizeStorageClass(storageClassSpec);
375 if (!storageClass) {
376 parser.emitError(storageClassLoc, "unknown storage class: ")
377 << storageClassSpec;
378 return Type();
379 }
380 if (parser.parseGreater())
381 return Type();
382 return PointerType::get(pointeeType, *storageClass);
383 }
384
385 // runtime-array-type ::= `!spv.rtarray` `<` element-type
386 // (`,` `stride` `=` integer-literal)? `>`
parseRuntimeArrayType(SPIRVDialect const & dialect,DialectAsmParser & parser)387 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
388 DialectAsmParser &parser) {
389 if (parser.parseLess())
390 return Type();
391
392 Type elementType = parseAndVerifyType(dialect, parser);
393 if (!elementType)
394 return Type();
395
396 unsigned stride = 0;
397 if (failed(parseOptionalArrayStride(dialect, parser, stride)))
398 return Type();
399
400 if (parser.parseGreater())
401 return Type();
402 return RuntimeArrayType::get(elementType, stride);
403 }
404
405 // matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>`
parseMatrixType(SPIRVDialect const & dialect,DialectAsmParser & parser)406 static Type parseMatrixType(SPIRVDialect const &dialect,
407 DialectAsmParser &parser) {
408 if (parser.parseLess())
409 return Type();
410
411 SmallVector<int64_t, 1> countDims;
412 SMLoc countLoc = parser.getCurrentLocation();
413 if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
414 return Type();
415 if (countDims.size() != 1) {
416 parser.emitError(countLoc, "expected single unsigned "
417 "integer for number of columns");
418 return Type();
419 }
420
421 int64_t columnCount = countDims[0];
422 // According to the specification, Matrices can have 2, 3, or 4 columns
423 if (columnCount < 2 || columnCount > 4) {
424 parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
425 "columns");
426 return Type();
427 }
428
429 Type columnType = parseAndVerifyMatrixType(dialect, parser);
430 if (!columnType)
431 return Type();
432
433 if (parser.parseGreater())
434 return Type();
435
436 return MatrixType::get(columnType, columnCount);
437 }
438
439 // Specialize this function to parse each of the parameters that define an
440 // ImageType. By default it assumes this is an enum type.
441 template <typename ValTy>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)442 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
443 DialectAsmParser &parser) {
444 StringRef enumSpec;
445 SMLoc enumLoc = parser.getCurrentLocation();
446 if (parser.parseKeyword(&enumSpec)) {
447 return llvm::None;
448 }
449
450 auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
451 if (!val)
452 parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
453 return val;
454 }
455
456 template <>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)457 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
458 DialectAsmParser &parser) {
459 // TODO: Further verify that the element type can be sampled
460 auto ty = parseAndVerifyType(dialect, parser);
461 if (!ty)
462 return llvm::None;
463 return ty;
464 }
465
466 template <typename IntTy>
parseAndVerifyInteger(SPIRVDialect const & dialect,DialectAsmParser & parser)467 static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
468 DialectAsmParser &parser) {
469 IntTy offsetVal = std::numeric_limits<IntTy>::max();
470 if (parser.parseInteger(offsetVal))
471 return llvm::None;
472 return offsetVal;
473 }
474
475 template <>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)476 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
477 DialectAsmParser &parser) {
478 return parseAndVerifyInteger<unsigned>(dialect, parser);
479 }
480
481 namespace {
482 // Functor object to parse a comma separated list of specs. The function
483 // parseAndVerify does the actual parsing and verification of individual
484 // elements. This is a functor since parsing the last element of the list
485 // (termination condition) needs partial specialization.
486 template <typename ParseType, typename... Args>
487 struct ParseCommaSeparatedList {
488 Optional<std::tuple<ParseType, Args...>>
operator ()__anon52aeee350311::ParseCommaSeparatedList489 operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
490 auto parseVal = parseAndVerify<ParseType>(dialect, parser);
491 if (!parseVal)
492 return llvm::None;
493
494 auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
495 if (numArgs != 0 && failed(parser.parseComma()))
496 return llvm::None;
497 auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
498 if (!remainingValues)
499 return llvm::None;
500 return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
501 remainingValues.value());
502 }
503 };
504
505 // Partial specialization of the function to parse a comma separated list of
506 // specs to parse the last element of the list.
507 template <typename ParseType>
508 struct ParseCommaSeparatedList<ParseType> {
operator ()__anon52aeee350311::ParseCommaSeparatedList509 Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
510 DialectAsmParser &parser) const {
511 if (auto value = parseAndVerify<ParseType>(dialect, parser))
512 return std::tuple<ParseType>(*value);
513 return llvm::None;
514 }
515 };
516 } // namespace
517
518 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
519 //
520 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
521 //
522 // arrayed-info ::= `NonArrayed` | `Arrayed`
523 //
524 // sampling-info ::= `SingleSampled` | `MultiSampled`
525 //
526 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler`
527 //
528 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
529 //
530 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
531 // arrayed-info `,` sampling-info `,`
532 // sampler-use-info `,` format `>`
parseImageType(SPIRVDialect const & dialect,DialectAsmParser & parser)533 static Type parseImageType(SPIRVDialect const &dialect,
534 DialectAsmParser &parser) {
535 if (parser.parseLess())
536 return Type();
537
538 auto value =
539 ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
540 ImageSamplingInfo, ImageSamplerUseInfo,
541 ImageFormat>{}(dialect, parser);
542 if (!value)
543 return Type();
544
545 if (parser.parseGreater())
546 return Type();
547 return ImageType::get(*value);
548 }
549
550 // sampledImage-type :: = `!spv.sampledImage<` image-type `>`
parseSampledImageType(SPIRVDialect const & dialect,DialectAsmParser & parser)551 static Type parseSampledImageType(SPIRVDialect const &dialect,
552 DialectAsmParser &parser) {
553 if (parser.parseLess())
554 return Type();
555
556 Type parsedType = parseAndVerifySampledImageType(dialect, parser);
557 if (!parsedType)
558 return Type();
559
560 if (parser.parseGreater())
561 return Type();
562 return SampledImageType::get(parsedType);
563 }
564
565 // Parse decorations associated with a member.
parseStructMemberDecorations(SPIRVDialect const & dialect,DialectAsmParser & parser,ArrayRef<Type> memberTypes,SmallVectorImpl<StructType::OffsetInfo> & offsetInfo,SmallVectorImpl<StructType::MemberDecorationInfo> & memberDecorationInfo)566 static ParseResult parseStructMemberDecorations(
567 SPIRVDialect const &dialect, DialectAsmParser &parser,
568 ArrayRef<Type> memberTypes,
569 SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
570 SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
571
572 // Check if the first element is offset.
573 SMLoc offsetLoc = parser.getCurrentLocation();
574 StructType::OffsetInfo offset = 0;
575 OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
576 if (offsetParseResult.hasValue()) {
577 if (failed(*offsetParseResult))
578 return failure();
579
580 if (offsetInfo.size() != memberTypes.size() - 1) {
581 return parser.emitError(offsetLoc,
582 "offset specification must be given for "
583 "all members");
584 }
585 offsetInfo.push_back(offset);
586 }
587
588 // Check for no spirv::Decorations.
589 if (succeeded(parser.parseOptionalRSquare()))
590 return success();
591
592 // If there was an offset, make sure to parse the comma.
593 if (offsetParseResult.hasValue() && parser.parseComma())
594 return failure();
595
596 // Check for spirv::Decorations.
597 auto parseDecorations = [&]() {
598 auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
599 if (!memberDecoration)
600 return failure();
601
602 // Parse member decoration value if it exists.
603 if (succeeded(parser.parseOptionalEqual())) {
604 auto memberDecorationValue =
605 parseAndVerifyInteger<uint32_t>(dialect, parser);
606
607 if (!memberDecorationValue)
608 return failure();
609
610 memberDecorationInfo.emplace_back(
611 static_cast<uint32_t>(memberTypes.size() - 1), 1,
612 memberDecoration.value(), memberDecorationValue.value());
613 } else {
614 memberDecorationInfo.emplace_back(
615 static_cast<uint32_t>(memberTypes.size() - 1), 0,
616 memberDecoration.value(), 0);
617 }
618 return success();
619 };
620 if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
621 failed(parser.parseRSquare()))
622 return failure();
623
624 return success();
625 }
626
627 // struct-member-decoration ::= integer-literal? spirv-decoration*
628 // struct-type ::=
629 // `!spv.struct<` (id `,`)?
630 // `(`
631 // (spirv-type (`[` struct-member-decoration `]`)?)*
632 // `)>`
parseStructType(SPIRVDialect const & dialect,DialectAsmParser & parser)633 static Type parseStructType(SPIRVDialect const &dialect,
634 DialectAsmParser &parser) {
635 // TODO: This function is quite lengthy. Break it down into smaller chunks.
636
637 // To properly resolve recursive references while parsing recursive struct
638 // types, we need to maintain a list of enclosing struct type names. This set
639 // maintains the names of struct types in which the type we are about to parse
640 // is nested.
641 //
642 // Note: This has to be thread_local to enable multiple threads to safely
643 // parse concurrently.
644 thread_local SetVector<StringRef> structContext;
645
646 static auto removeIdentifierAndFail = [](SetVector<StringRef> &structContext,
647 StringRef identifier) {
648 if (!identifier.empty())
649 structContext.remove(identifier);
650
651 return Type();
652 };
653
654 if (parser.parseLess())
655 return Type();
656
657 StringRef identifier;
658
659 // Check if this is an identified struct type.
660 if (succeeded(parser.parseOptionalKeyword(&identifier))) {
661 // Check if this is a possible recursive reference.
662 if (succeeded(parser.parseOptionalGreater())) {
663 if (structContext.count(identifier) == 0) {
664 parser.emitError(
665 parser.getNameLoc(),
666 "recursive struct reference not nested in struct definition");
667
668 return Type();
669 }
670
671 return StructType::getIdentified(dialect.getContext(), identifier);
672 }
673
674 if (failed(parser.parseComma()))
675 return Type();
676
677 if (structContext.count(identifier) != 0) {
678 parser.emitError(parser.getNameLoc(),
679 "identifier already used for an enclosing struct");
680
681 return removeIdentifierAndFail(structContext, identifier);
682 }
683
684 structContext.insert(identifier);
685 }
686
687 if (failed(parser.parseLParen()))
688 return removeIdentifierAndFail(structContext, identifier);
689
690 if (succeeded(parser.parseOptionalRParen()) &&
691 succeeded(parser.parseOptionalGreater())) {
692 if (!identifier.empty())
693 structContext.remove(identifier);
694
695 return StructType::getEmpty(dialect.getContext(), identifier);
696 }
697
698 StructType idStructTy;
699
700 if (!identifier.empty())
701 idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
702
703 SmallVector<Type, 4> memberTypes;
704 SmallVector<StructType::OffsetInfo, 4> offsetInfo;
705 SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
706
707 do {
708 Type memberType;
709 if (parser.parseType(memberType))
710 return removeIdentifierAndFail(structContext, identifier);
711 memberTypes.push_back(memberType);
712
713 if (succeeded(parser.parseOptionalLSquare()))
714 if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
715 memberDecorationInfo))
716 return removeIdentifierAndFail(structContext, identifier);
717 } while (succeeded(parser.parseOptionalComma()));
718
719 if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
720 parser.emitError(parser.getNameLoc(),
721 "offset specification must be given for all members");
722 return removeIdentifierAndFail(structContext, identifier);
723 }
724
725 if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
726 return removeIdentifierAndFail(structContext, identifier);
727
728 if (!identifier.empty()) {
729 if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
730 memberDecorationInfo)))
731 return Type();
732
733 structContext.remove(identifier);
734 return idStructTy;
735 }
736
737 return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
738 }
739
740 // spirv-type ::= array-type
741 // | element-type
742 // | image-type
743 // | pointer-type
744 // | runtime-array-type
745 // | sampled-image-type
746 // | struct-type
parseType(DialectAsmParser & parser) const747 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
748 StringRef keyword;
749 if (parser.parseKeyword(&keyword))
750 return Type();
751
752 if (keyword == "array")
753 return parseArrayType(*this, parser);
754 if (keyword == "coopmatrix")
755 return parseCooperativeMatrixType(*this, parser);
756 if (keyword == "image")
757 return parseImageType(*this, parser);
758 if (keyword == "ptr")
759 return parsePointerType(*this, parser);
760 if (keyword == "rtarray")
761 return parseRuntimeArrayType(*this, parser);
762 if (keyword == "sampled_image")
763 return parseSampledImageType(*this, parser);
764 if (keyword == "struct")
765 return parseStructType(*this, parser);
766 if (keyword == "matrix")
767 return parseMatrixType(*this, parser);
768 parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
769 return Type();
770 }
771
772 //===----------------------------------------------------------------------===//
773 // Type Printing
774 //===----------------------------------------------------------------------===//
775
print(ArrayType type,DialectAsmPrinter & os)776 static void print(ArrayType type, DialectAsmPrinter &os) {
777 os << "array<" << type.getNumElements() << " x " << type.getElementType();
778 if (unsigned stride = type.getArrayStride())
779 os << ", stride=" << stride;
780 os << ">";
781 }
782
print(RuntimeArrayType type,DialectAsmPrinter & os)783 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
784 os << "rtarray<" << type.getElementType();
785 if (unsigned stride = type.getArrayStride())
786 os << ", stride=" << stride;
787 os << ">";
788 }
789
print(PointerType type,DialectAsmPrinter & os)790 static void print(PointerType type, DialectAsmPrinter &os) {
791 os << "ptr<" << type.getPointeeType() << ", "
792 << stringifyStorageClass(type.getStorageClass()) << ">";
793 }
794
print(ImageType type,DialectAsmPrinter & os)795 static void print(ImageType type, DialectAsmPrinter &os) {
796 os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
797 << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
798 << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
799 << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
800 << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
801 << stringifyImageFormat(type.getImageFormat()) << ">";
802 }
803
print(SampledImageType type,DialectAsmPrinter & os)804 static void print(SampledImageType type, DialectAsmPrinter &os) {
805 os << "sampled_image<" << type.getImageType() << ">";
806 }
807
print(StructType type,DialectAsmPrinter & os)808 static void print(StructType type, DialectAsmPrinter &os) {
809 thread_local SetVector<StringRef> structContext;
810
811 os << "struct<";
812
813 if (type.isIdentified()) {
814 os << type.getIdentifier();
815
816 if (structContext.count(type.getIdentifier())) {
817 os << ">";
818 return;
819 }
820
821 os << ", ";
822 structContext.insert(type.getIdentifier());
823 }
824
825 os << "(";
826
827 auto printMember = [&](unsigned i) {
828 os << type.getElementType(i);
829 SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
830 type.getMemberDecorations(i, decorations);
831 if (type.hasOffset() || !decorations.empty()) {
832 os << " [";
833 if (type.hasOffset()) {
834 os << type.getMemberOffset(i);
835 if (!decorations.empty())
836 os << ", ";
837 }
838 auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
839 os << stringifyDecoration(decoration.decoration);
840 if (decoration.hasValue) {
841 os << "=" << decoration.decorationValue;
842 }
843 };
844 llvm::interleaveComma(decorations, os, eachFn);
845 os << "]";
846 }
847 };
848 llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
849 printMember);
850 os << ")>";
851
852 if (type.isIdentified())
853 structContext.remove(type.getIdentifier());
854 }
855
print(CooperativeMatrixNVType type,DialectAsmPrinter & os)856 static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
857 os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
858 os << type.getElementType() << ", " << stringifyScope(type.getScope());
859 os << ">";
860 }
861
print(MatrixType type,DialectAsmPrinter & os)862 static void print(MatrixType type, DialectAsmPrinter &os) {
863 os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
864 os << ">";
865 }
866
printType(Type type,DialectAsmPrinter & os) const867 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
868 TypeSwitch<Type>(type)
869 .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
870 ImageType, SampledImageType, StructType, MatrixType>(
871 [&](auto type) { print(type, os); })
872 .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
873 }
874
875 //===----------------------------------------------------------------------===//
876 // Constant
877 //===----------------------------------------------------------------------===//
878
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)879 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
880 Attribute value, Type type,
881 Location loc) {
882 if (!spirv::ConstantOp::isBuildableWith(type))
883 return nullptr;
884
885 return builder.create<spirv::ConstantOp>(loc, type, value);
886 }
887
888 //===----------------------------------------------------------------------===//
889 // Shader Interface ABI
890 //===----------------------------------------------------------------------===//
891
verifyOperationAttribute(Operation * op,NamedAttribute attribute)892 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
893 NamedAttribute attribute) {
894 StringRef symbol = attribute.getName().strref();
895 Attribute attr = attribute.getValue();
896
897 if (symbol == spirv::getEntryPointABIAttrName()) {
898 if (!attr.isa<spirv::EntryPointABIAttr>()) {
899 return op->emitError("'")
900 << symbol << "' attribute must be an entry point ABI attribute";
901 }
902 } else if (symbol == spirv::getTargetEnvAttrName()) {
903 if (!attr.isa<spirv::TargetEnvAttr>())
904 return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
905 } else {
906 return op->emitError("found unsupported '")
907 << symbol << "' attribute on operation";
908 }
909
910 return success();
911 }
912
913 /// Verifies the given SPIR-V `attribute` attached to a value of the given
914 /// `valueType` is valid.
verifyRegionAttribute(Location loc,Type valueType,NamedAttribute attribute)915 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
916 NamedAttribute attribute) {
917 StringRef symbol = attribute.getName().strref();
918 Attribute attr = attribute.getValue();
919
920 if (symbol != spirv::getInterfaceVarABIAttrName())
921 return emitError(loc, "found unsupported '")
922 << symbol << "' attribute on region argument";
923
924 auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
925 if (!varABIAttr)
926 return emitError(loc, "'")
927 << symbol << "' must be a spirv::InterfaceVarABIAttr";
928
929 if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
930 return emitError(loc, "'") << symbol
931 << "' attribute cannot specify storage class "
932 "when attaching to a non-scalar value";
933
934 return success();
935 }
936
verifyRegionArgAttribute(Operation * op,unsigned regionIndex,unsigned argIndex,NamedAttribute attribute)937 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
938 unsigned regionIndex,
939 unsigned argIndex,
940 NamedAttribute attribute) {
941 return verifyRegionAttribute(
942 op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
943 attribute);
944 }
945
verifyRegionResultAttribute(Operation * op,unsigned,unsigned,NamedAttribute attribute)946 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
947 Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
948 NamedAttribute attribute) {
949 return op->emitError("cannot attach SPIR-V attributes to region result");
950 }
951