1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <functional>
22 
23 #define DEBUG_TYPE "mlir-spirv-conversion"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Checks that `candidates` extension requirements are possible to be satisfied
32 /// with the given `targetEnv`.
33 ///
34 ///  `candidates` is a vector of vector for extension requirements following
35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
36 /// convention.
37 template <typename LabelT>
38 static LogicalResult checkExtensionRequirements(
39     LabelT label, const spirv::TargetEnv &targetEnv,
40     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
41   for (const auto &ors : candidates) {
42     if (targetEnv.allows(ors))
43       continue;
44 
45     SmallVector<StringRef, 4> extStrings;
46     for (spirv::Extension ext : ors)
47       extStrings.push_back(spirv::stringifyExtension(ext));
48 
49     LLVM_DEBUG(llvm::dbgs()
50                << label << " illegal: requires at least one extension in ["
51                << llvm::join(extStrings, ", ")
52                << "] but none allowed in target environment\n");
53     return failure();
54   }
55   return success();
56 }
57 
58 /// Checks that `candidates`capability requirements are possible to be satisfied
59 /// with the given `isAllowedFn`.
60 ///
61 ///  `candidates` is a vector of vector for capability requirements following
62 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
63 /// convention.
64 template <typename LabelT>
65 static LogicalResult checkCapabilityRequirements(
66     LabelT label, const spirv::TargetEnv &targetEnv,
67     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
68   for (const auto &ors : candidates) {
69     if (targetEnv.allows(ors))
70       continue;
71 
72     SmallVector<StringRef, 4> capStrings;
73     for (spirv::Capability cap : ors)
74       capStrings.push_back(spirv::stringifyCapability(cap));
75 
76     LLVM_DEBUG(llvm::dbgs()
77                << label << " illegal: requires at least one capability in ["
78                << llvm::join(capStrings, ", ")
79                << "] but none allowed in target environment\n");
80     return failure();
81   }
82   return success();
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Type Conversion
87 //===----------------------------------------------------------------------===//
88 
89 Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
90   // Convert to 32-bit integers for now. Might need a way to control this in
91   // future.
92   // TODO: It is probably better to make it 64-bit integers. To
93   // this some support is needed in SPIR-V dialect for Conversion
94   // instructions. The Vulkan spec requires the builtins like
95   // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
96   // SExtended to 64-bit for index computations.
97   return IntegerType::get(context, 32);
98 }
99 
/// Mapping between SPIR-V storage classes and memref memory spaces.
///
/// Note: memref does not have a defined semantics for each memory space; it
/// depends on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes a smaller number. The hope is to use a symbolic
/// memory space representation eventually, once memref supports it.
///
/// This is an X-macro table: each `MAP_FN(storageClass, memorySpace)` entry is
/// expanded by the two conversion functions below, so both directions of the
/// mapping are generated from this single list.
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
  MAP_FN(spirv::StorageClass::Image, 13)                                       \
  MAP_FN(spirv::StorageClass::CallableDataNV, 14)                              \
  MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15)                      \
  MAP_FN(spirv::StorageClass::RayPayloadNV, 16)                                \
  MAP_FN(spirv::StorageClass::HitAttributeNV, 17)                              \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18)                        \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19)                        \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)

/// Returns the memref memory space number assigned to `storage` by the
/// STORAGE_SPACE_MAP_LIST table. Storage classes missing from the table reach
/// llvm_unreachable.
unsigned
SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
// Each table entry expands to `case <storage class>: return <memory space>;`.
#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
  case storage:                                                                \
    return space;

  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
#undef STORAGE_SPACE_MAP_FN
  llvm_unreachable("unhandled storage class!");
}

/// Returns the storage class that the STORAGE_SPACE_MAP_LIST table assigns to
/// memory space number `space`, or llvm::None if no entry uses that number.
Optional<spirv::StorageClass>
SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
// Each table entry expands to `case <memory space>: return <storage class>;`.
#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
  case space:                                                                  \
    return storage;

  switch (space) {
    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
  default:
    return llvm::None;
  }
#undef STORAGE_SPACE_MAP_FN
}

// The table is only needed by the two functions above.
#undef STORAGE_SPACE_MAP_LIST
157 
158 // TODO: This is a utility function that should probably be
159 // exposed by the SPIR-V dialect. Keeping it local till the use case arises.
160 static Optional<int64_t> getTypeNumBytes(Type t) {
161   if (t.isa<spirv::ScalarType>()) {
162     auto bitWidth = t.getIntOrFloatBitWidth();
163     // According to the SPIR-V spec:
164     // "There is no physical size or bit pattern defined for values with boolean
165     // type. If they are stored (in conjunction with OpVariable), they can only
166     // be used with logical addressing operations, not physical, and only with
167     // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
168     // Private, Function, Input, and Output."
169     if (bitWidth == 1) {
170       return llvm::None;
171     }
172     return bitWidth / 8;
173   }
174 
175   if (auto vecType = t.dyn_cast<VectorType>()) {
176     auto elementSize = getTypeNumBytes(vecType.getElementType());
177     if (!elementSize)
178       return llvm::None;
179     return vecType.getNumElements() * *elementSize;
180   }
181 
182   if (auto memRefType = t.dyn_cast<MemRefType>()) {
183     // TODO: Layout should also be controlled by the ABI attributes. For now
184     // using the layout from MemRef.
185     int64_t offset;
186     SmallVector<int64_t, 4> strides;
187     if (!memRefType.hasStaticShape() ||
188         failed(getStridesAndOffset(memRefType, strides, offset))) {
189       return llvm::None;
190     }
191     // To get the size of the memref object in memory, the total size is the
192     // max(stride * dimension-size) computed for all dimensions times the size
193     // of the element.
194     auto elementSize = getTypeNumBytes(memRefType.getElementType());
195     if (!elementSize) {
196       return llvm::None;
197     }
198     if (memRefType.getRank() == 0) {
199       return elementSize;
200     }
201     auto dims = memRefType.getShape();
202     if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
203         offset == MemRefType::getDynamicStrideOrOffset() ||
204         llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
205       return llvm::None;
206     }
207     int64_t memrefSize = -1;
208     for (auto shape : enumerate(dims)) {
209       memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
210     }
211     return (offset + memrefSize) * elementSize.getValue();
212   }
213 
214   if (auto tensorType = t.dyn_cast<TensorType>()) {
215     if (!tensorType.hasStaticShape()) {
216       return llvm::None;
217     }
218     auto elementSize = getTypeNumBytes(tensorType.getElementType());
219     if (!elementSize) {
220       return llvm::None;
221     }
222     int64_t size = elementSize.getValue();
223     for (auto shape : tensorType.getShape()) {
224       size *= shape;
225     }
226     return size;
227   }
228 
229   // TODO: Add size computation for other types.
230   return llvm::None;
231 }
232 
233 Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
234   return getTypeNumBytes(t);
235 }
236 
237 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
238 static Optional<Type>
239 convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
240                   Optional<spirv::StorageClass> storageClass = {}) {
241   // Get extension and capability requirements for the given type.
242   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
243   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
244   type.getExtensions(extensions, storageClass);
245   type.getCapabilities(capabilities, storageClass);
246 
247   // If all requirements are met, then we can accept this type as-is.
248   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
249       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
250     return type;
251 
252   // Otherwise we need to adjust the type, which really means adjusting the
253   // bitwidth given this is a scalar type.
254   // TODO: We are unconditionally converting the bitwidth here,
255   // this might be okay for non-interface types (i.e., types used in
256   // Private/Function storage classes), but not for interface types (i.e.,
257   // types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
258   // This is because the later actually affects the ABI contract with the
259   // runtime. So we may want to expose a control on SPIRVTypeConverter to fail
260   // conversion if we cannot change there.
261 
262   if (auto floatType = type.dyn_cast<FloatType>()) {
263     LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
264     return Builder(targetEnv.getContext()).getF32Type();
265   }
266 
267   auto intType = type.cast<IntegerType>();
268   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
269   return IntegerType::get(targetEnv.getContext(), /*width=*/32,
270                           intType.getSignedness());
271 }
272 
273 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
274 static Optional<Type>
275 convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
276                   Optional<spirv::StorageClass> storageClass = {}) {
277   if (type.getRank() == 1 && type.getNumElements() == 1)
278     return type.getElementType();
279 
280   if (!spirv::CompositeType::isValid(type)) {
281     // TODO: Vector types with more than four elements can be translated into
282     // array types.
283     LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
284     return llvm::None;
285   }
286 
287   // Get extension and capability requirements for the given type.
288   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
289   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
290   type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
291   type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
292 
293   // If all requirements are met, then we can accept this type as-is.
294   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
295       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
296     return type;
297 
298   auto elementType = convertScalarType(
299       targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
300   if (elementType)
301     return VectorType::get(type.getShape(), *elementType);
302   return llvm::None;
303 }
304 
305 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
306 ///
307 /// Note that this is mainly for lowering constant tensors.In SPIR-V one can
308 /// create composite constants with OpConstantComposite to embed relative large
309 /// constant values and use OpCompositeExtract and OpCompositeInsert to
310 /// manipulate, like what we do for vectors.
311 static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
312                                         TensorType type) {
313   // TODO: Handle dynamic shapes.
314   if (!type.hasStaticShape()) {
315     LLVM_DEBUG(llvm::dbgs()
316                << type << " illegal: dynamic shape unimplemented\n");
317     return llvm::None;
318   }
319 
320   auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
321   if (!scalarType) {
322     LLVM_DEBUG(llvm::dbgs()
323                << type << " illegal: cannot convert non-scalar element type\n");
324     return llvm::None;
325   }
326 
327   Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
328   Optional<int64_t> tensorSize = getTypeNumBytes(type);
329   if (!scalarSize || !tensorSize) {
330     LLVM_DEBUG(llvm::dbgs()
331                << type << " illegal: cannot deduce element count\n");
332     return llvm::None;
333   }
334 
335   auto arrayElemCount = *tensorSize / *scalarSize;
336   auto arrayElemType = convertScalarType(targetEnv, scalarType);
337   if (!arrayElemType)
338     return llvm::None;
339   Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
340   if (!arrayElemSize) {
341     LLVM_DEBUG(llvm::dbgs()
342                << type << " illegal: cannot deduce converted element size\n");
343     return llvm::None;
344   }
345 
346   return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
347 }
348 
349 static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
350                                         MemRefType type) {
351   Optional<spirv::StorageClass> storageClass =
352       SPIRVTypeConverter::getStorageClassForMemorySpace(
353           type.getMemorySpaceAsInt());
354   if (!storageClass) {
355     LLVM_DEBUG(llvm::dbgs()
356                << type << " illegal: cannot convert memory space\n");
357     return llvm::None;
358   }
359 
360   Optional<Type> arrayElemType;
361   Type elementType = type.getElementType();
362   if (auto vecType = elementType.dyn_cast<VectorType>()) {
363     arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
364   } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
365     arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
366   } else {
367     LLVM_DEBUG(
368         llvm::dbgs()
369         << type
370         << " unhandled: can only convert scalar or vector element type\n");
371     return llvm::None;
372   }
373   if (!arrayElemType)
374     return llvm::None;
375 
376   Optional<int64_t> elementSize = getTypeNumBytes(elementType);
377   if (!elementSize) {
378     LLVM_DEBUG(llvm::dbgs()
379                << type << " illegal: cannot deduce element size\n");
380     return llvm::None;
381   }
382 
383   if (!type.hasStaticShape()) {
384     auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
385     // Wrap in a struct to satisfy Vulkan interface requirements.
386     auto structType = spirv::StructType::get(arrayType, 0);
387     return spirv::PointerType::get(structType, *storageClass);
388   }
389 
390   Optional<int64_t> memrefSize = getTypeNumBytes(type);
391   if (!memrefSize) {
392     LLVM_DEBUG(llvm::dbgs()
393                << type << " illegal: cannot deduce element count\n");
394     return llvm::None;
395   }
396 
397   auto arrayElemCount = *memrefSize / *elementSize;
398 
399   Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
400   if (!arrayElemSize) {
401     LLVM_DEBUG(llvm::dbgs()
402                << type << " illegal: cannot deduce converted element size\n");
403     return llvm::None;
404   }
405 
406   auto arrayType =
407       spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
408 
409   // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
410   // workgroup storage class do not need the struct to be laid out explicitly.
411   auto structType = *storageClass == spirv::StorageClass::Workgroup
412                         ? spirv::StructType::get(arrayType)
413                         : spirv::StructType::get(arrayType, 0);
414   return spirv::PointerType::get(structType, *storageClass);
415 }
416 
/// Creates a type converter that maps builtin types to SPIR-V types allowed by
/// the target environment described by `targetAttr`.
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
    : targetEnv(targetAttr) {
  // Add conversions. The order matters here: later ones will be tried earlier.

  // All other cases failed. Then we cannot convert this type. Registered
  // first so that it is tried last, as the catch-all.
  addConversion([](Type type) { return llvm::None; });

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: this assumes that the SPIR-V types are valid to use in
  // the given target environment, which should be the case if the whole
  // pipeline is driven by the same target environment. Still, we probably
  // want to validate and convert to be safe.
  addConversion([](spirv::SPIRVType type) { return type; });

  // `index` is lowered to the fixed integer type chosen by getIndexType().
  addConversion([](IndexType indexType) {
    return SPIRVTypeConverter::getIndexType(indexType.getContext());
  });

  // Integer types that are valid SPIR-V scalars are adjusted to what the
  // target environment supports; others yield llvm::None so that the
  // earlier-registered conversions are tried.
  addConversion([this](IntegerType intType) -> Optional<Type> {
    if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  // Same treatment as integers, for float types.
  addConversion([this](FloatType floatType) -> Optional<Type> {
    if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  addConversion([this](VectorType vectorType) {
    return convertVectorType(targetEnv, vectorType);
  });

  // Tensors are lowered to spirv::ArrayType (mainly for constant tensors).
  addConversion([this](TensorType tensorType) {
    return convertTensorType(targetEnv, tensorType);
  });

  // Memrefs are lowered to pointers to struct-wrapped arrays in the storage
  // class derived from their memory space.
  addConversion([this](MemRefType memRefType) {
    return convertMemrefType(targetEnv, memRefType);
  });
}
462 
463 //===----------------------------------------------------------------------===//
464 // FuncOp Conversion Patterns
465 //===----------------------------------------------------------------------===//
466 
467 namespace {
468 /// A pattern for rewriting function signature to convert arguments of functions
469 /// to be of valid SPIR-V types.
470 class FuncOpConversion final : public OpConversionPattern<FuncOp> {
471 public:
472   using OpConversionPattern<FuncOp>::OpConversionPattern;
473 
474   LogicalResult
475   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
476                   ConversionPatternRewriter &rewriter) const override;
477 };
478 } // namespace
479 
480 LogicalResult
481 FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
482                                   ConversionPatternRewriter &rewriter) const {
483   auto fnType = funcOp.getType();
484   if (fnType.getNumResults() > 1)
485     return failure();
486 
487   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
488   for (auto argType : enumerate(fnType.getInputs())) {
489     auto convertedType = getTypeConverter()->convertType(argType.value());
490     if (!convertedType)
491       return failure();
492     signatureConverter.addInputs(argType.index(), convertedType);
493   }
494 
495   Type resultType;
496   if (fnType.getNumResults() == 1)
497     resultType = getTypeConverter()->convertType(fnType.getResult(0));
498 
499   // Create the converted spv.func op.
500   auto newFuncOp = rewriter.create<spirv::FuncOp>(
501       funcOp.getLoc(), funcOp.getName(),
502       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
503                                resultType ? TypeRange(resultType)
504                                           : TypeRange()));
505 
506   // Copy over all attributes other than the function name and type.
507   for (const auto &namedAttr : funcOp->getAttrs()) {
508     if (namedAttr.first != impl::getTypeAttrName() &&
509         namedAttr.first != SymbolTable::getSymbolAttrName())
510       newFuncOp->setAttr(namedAttr.first, namedAttr.second);
511   }
512 
513   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
514                               newFuncOp.end());
515   if (failed(rewriter.convertRegionTypes(
516           &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
517     return failure();
518   rewriter.eraseOp(funcOp);
519   return success();
520 }
521 
522 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
523                                               RewritePatternSet &patterns) {
524   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // Builtin Variables
529 //===----------------------------------------------------------------------===//
530 
531 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
532                                                   spirv::BuiltIn builtin) {
533   // Look through all global variables in the given `body` block and check if
534   // there is a spv.GlobalVariable that has the same `builtin` attribute.
535   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
536     if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
537             spirv::SPIRVDialect::getAttributeName(
538                 spirv::Decoration::BuiltIn))) {
539       auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
540       if (varBuiltIn && varBuiltIn.getValue() == builtin) {
541         return varOp;
542       }
543     }
544   }
545   return nullptr;
546 }
547 
548 /// Gets name of global variable for a builtin.
549 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
550   return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
551 }
552 
553 /// Gets or inserts a global variable for a builtin within `body` block.
554 static spirv::GlobalVariableOp
555 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
556                            OpBuilder &builder) {
557   if (auto varOp = getBuiltinVariable(body, builtin))
558     return varOp;
559 
560   OpBuilder::InsertionGuard guard(builder);
561   builder.setInsertionPointToStart(&body);
562 
563   spirv::GlobalVariableOp newVarOp;
564   switch (builtin) {
565   case spirv::BuiltIn::NumWorkgroups:
566   case spirv::BuiltIn::WorkgroupSize:
567   case spirv::BuiltIn::WorkgroupId:
568   case spirv::BuiltIn::LocalInvocationId:
569   case spirv::BuiltIn::GlobalInvocationId: {
570     auto ptrType = spirv::PointerType::get(
571         VectorType::get({3}, builder.getIntegerType(32)),
572         spirv::StorageClass::Input);
573     std::string name = getBuiltinVarName(builtin);
574     newVarOp =
575         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
576     break;
577   }
578   case spirv::BuiltIn::SubgroupId:
579   case spirv::BuiltIn::NumSubgroups:
580   case spirv::BuiltIn::SubgroupSize: {
581     auto ptrType = spirv::PointerType::get(builder.getIntegerType(32),
582                                            spirv::StorageClass::Input);
583     std::string name = getBuiltinVarName(builtin);
584     newVarOp =
585         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
586     break;
587   }
588   default:
589     emitError(loc, "unimplemented builtin variable generation for ")
590         << stringifyBuiltIn(builtin);
591   }
592   return newVarOp;
593 }
594 
595 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
596                                            spirv::BuiltIn builtin,
597                                            OpBuilder &builder) {
598   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
599   if (!parent) {
600     op->emitError("expected operation to be within a module-like op");
601     return nullptr;
602   }
603 
604   spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
605       *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
606   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
607   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
608 }
609 
610 //===----------------------------------------------------------------------===//
611 // Push constant storage
612 //===----------------------------------------------------------------------===//
613 
614 /// Returns the pointer type for the push constant storage containing
615 /// `elementCount` 32-bit integer values.
616 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
617                                                      Builder &builder) {
618   auto arrayType = spirv::ArrayType::get(
619       SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount,
620       /*stride=*/4);
621   auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
622   return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
623 }
624 
625 /// Returns the push constant varible containing `elementCount` 32-bit integer
626 /// values in `body`. Returns null op if such an op does not exit.
627 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
628                                                        unsigned elementCount) {
629   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
630     auto ptrType = varOp.type().cast<spirv::PointerType>();
631     // Note that Vulkan requires "There must be no more than one push constant
632     // block statically used per shader entry point." So we should always reuse
633     // the existing one.
634     if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
635       auto numElements = ptrType.getPointeeType()
636                              .cast<spirv::StructType>()
637                              .getElementType(0)
638                              .cast<spirv::ArrayType>()
639                              .getNumElements();
640       if (numElements == elementCount)
641         return varOp;
642     }
643   }
644   return nullptr;
645 }
646 
647 /// Gets or inserts a global variable for push constant storage containing
648 /// `elementCount` 32-bit integer values in `block`.
649 static spirv::GlobalVariableOp
650 getOrInsertPushConstantVariable(Location loc, Block &block,
651                                 unsigned elementCount, OpBuilder &b) {
652   if (auto varOp = getPushConstantVariable(block, elementCount))
653     return varOp;
654 
655   auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
656   auto type = getPushConstantStorageType(elementCount, builder);
657   const char *name = "__push_constant_var__";
658   return builder.create<spirv::GlobalVariableOp>(loc, type, name,
659                                                  /*initializer=*/nullptr);
660 }
661 
662 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
663                                   unsigned offset, OpBuilder &builder) {
664   Location loc = op->getLoc();
665   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
666   if (!parent) {
667     op->emitError("expected operation to be within a module-like op");
668     return nullptr;
669   }
670 
671   spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
672       loc, parent->getRegion(0).front(), elementCount, builder);
673 
674   auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext());
675   Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder);
676   Value offsetOp = builder.create<spirv::ConstantOp>(
677       loc, i32Type, builder.getI32IntegerAttr(offset));
678   auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
679   auto acOp = builder.create<spirv::AccessChainOp>(
680       loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
681   return builder.create<spirv::LoadOp>(loc, acOp);
682 }
683 
684 //===----------------------------------------------------------------------===//
685 // Index calculation
686 //===----------------------------------------------------------------------===//
687 
688 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
689                                   int64_t offset, Location loc,
690                                   OpBuilder &builder) {
691   assert(indices.size() == strides.size() &&
692          "must provide indices for all dimensions");
693 
694   auto indexType = SPIRVTypeConverter::getIndexType(builder.getContext());
695 
696   // TODO: Consider moving to use affine.apply and patterns converting
697   // affine.apply to standard ops. This needs converting to SPIR-V passes to be
698   // broken down into progressive small steps so we can have intermediate steps
699   // using other dialects. At the moment SPIR-V is the final sink.
700 
701   Value linearizedIndex = builder.create<spirv::ConstantOp>(
702       loc, indexType, IntegerAttr::get(indexType, offset));
703   for (auto index : llvm::enumerate(indices)) {
704     Value strideVal = builder.create<spirv::ConstantOp>(
705         loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
706     Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
707     linearizedIndex =
708         builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
709   }
710   return linearizedIndex;
711 }
712 
713 spirv::AccessChainOp mlir::spirv::getElementPtr(
714     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
715     ValueRange indices, Location loc, OpBuilder &builder) {
716   // Get base and offset of the MemRefType and verify they are static.
717 
718   int64_t offset;
719   SmallVector<int64_t, 4> strides;
720   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
721       llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
722       offset == MemRefType::getDynamicStrideOrOffset()) {
723     return nullptr;
724   }
725 
726   auto indexType = typeConverter.getIndexType(builder.getContext());
727 
728   SmallVector<Value, 2> linearizedIndices;
729   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
730 
731   // Add a '0' at the start to index into the struct.
732   linearizedIndices.push_back(zero);
733 
734   if (baseType.getRank() == 0) {
735     linearizedIndices.push_back(zero);
736   } else {
737     linearizedIndices.push_back(
738         linearizeIndex(indices, strides, offset, loc, builder));
739   }
740   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // SPIR-V ConversionTarget
745 //===----------------------------------------------------------------------===//
746 
747 std::unique_ptr<SPIRVConversionTarget>
748 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
749   std::unique_ptr<SPIRVConversionTarget> target(
750       // std::make_unique does not work here because the constructor is private.
751       new SPIRVConversionTarget(targetAttr));
752   SPIRVConversionTarget *targetPtr = target.get();
753   target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
754       // We need to capture the raw pointer here because it is stable:
755       // target will be destroyed once this function is returned.
756       [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
757   return target;
758 }
759 
760 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
761     : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
762 
763 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
764   // Make sure this op is available at the given version. Ops not implementing
765   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
766   // SPIR-V versions.
767   if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
768     if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
769       LLVM_DEBUG(llvm::dbgs()
770                  << op->getName() << " illegal: requiring min version "
771                  << spirv::stringifyVersion(minVersion.getMinVersion())
772                  << "\n");
773       return false;
774     }
775   if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
776     if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
777       LLVM_DEBUG(llvm::dbgs()
778                  << op->getName() << " illegal: requiring max version "
779                  << spirv::stringifyVersion(maxVersion.getMaxVersion())
780                  << "\n");
781       return false;
782     }
783 
784   // Make sure this op's required extensions are allowed to use. Ops not
785   // implementing QueryExtensionInterface do not require extensions to be
786   // available.
787   if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
788     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
789                                           extensions.getExtensions())))
790       return false;
791 
792   // Make sure this op's required extensions are allowed to use. Ops not
793   // implementing QueryCapabilityInterface do not require capabilities to be
794   // available.
795   if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
796     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
797                                            capabilities.getCapabilities())))
798       return false;
799 
800   SmallVector<Type, 4> valueTypes;
801   valueTypes.append(op->operand_type_begin(), op->operand_type_end());
802   valueTypes.append(op->result_type_begin(), op->result_type_end());
803 
804   // Special treatment for global variables, whose type requirements are
805   // conveyed by type attributes.
806   if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
807     valueTypes.push_back(globalVar.type());
808 
809   // Make sure the op's operands/results use types that are allowed by the
810   // target environment.
811   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
812   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
813   for (Type valueType : valueTypes) {
814     typeExtensions.clear();
815     valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
816     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
817                                           typeExtensions)))
818       return false;
819 
820     typeCapabilities.clear();
821     valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
822     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
823                                            typeCapabilities)))
824       return false;
825   }
826 
827   return true;
828 }
829