1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <functional>
22 
23 #define DEBUG_TYPE "mlir-spirv-conversion"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Checks that `candidates` extension requirements are possible to be satisfied
32 /// with the given `targetEnv`.
33 ///
34 ///  `candidates` is a vector of vector for extension requirements following
35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
36 /// convention.
37 template <typename LabelT>
38 static LogicalResult checkExtensionRequirements(
39     LabelT label, const spirv::TargetEnv &targetEnv,
40     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
41   for (const auto &ors : candidates) {
42     if (targetEnv.allows(ors))
43       continue;
44 
45     SmallVector<StringRef, 4> extStrings;
46     for (spirv::Extension ext : ors)
47       extStrings.push_back(spirv::stringifyExtension(ext));
48 
49     LLVM_DEBUG(llvm::dbgs()
50                << label << " illegal: requires at least one extension in ["
51                << llvm::join(extStrings, ", ")
52                << "] but none allowed in target environment\n");
53     return failure();
54   }
55   return success();
56 }
57 
58 /// Checks that `candidates`capability requirements are possible to be satisfied
59 /// with the given `isAllowedFn`.
60 ///
61 ///  `candidates` is a vector of vector for capability requirements following
62 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
63 /// convention.
64 template <typename LabelT>
65 static LogicalResult checkCapabilityRequirements(
66     LabelT label, const spirv::TargetEnv &targetEnv,
67     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
68   for (const auto &ors : candidates) {
69     if (targetEnv.allows(ors))
70       continue;
71 
72     SmallVector<StringRef, 4> capStrings;
73     for (spirv::Capability cap : ors)
74       capStrings.push_back(spirv::stringifyCapability(cap));
75 
76     LLVM_DEBUG(llvm::dbgs()
77                << label << " illegal: requires at least one capability in ["
78                << llvm::join(capStrings, ", ")
79                << "] but none allowed in target environment\n");
80     return failure();
81   }
82   return success();
83 }
84 
85 //===----------------------------------------------------------------------===//
86 // Type Conversion
87 //===----------------------------------------------------------------------===//
88 
89 Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
90   // Convert to 32-bit integers for now. Might need a way to control this in
91   // future.
92   // TODO: It is probably better to make it 64-bit integers. To
93   // this some support is needed in SPIR-V dialect for Conversion
94   // instructions. The Vulkan spec requires the builtins like
95   // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
96   // SExtended to 64-bit for index computations.
97   return IntegerType::get(context, 32);
98 }
99 
/// Mapping from SPIR-V storage classes to memref memory spaces.
///
/// Note: memref does not have a defined semantics for each memory space; it
/// depends on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes a smaller number. The hope is to use a symbolic
/// memory space representation eventually, after memref supports it.
///
/// Each entry has the form `MAP_FN(storage class, memory space)`. The table is
/// expanded into `switch` cases in both directions by
/// getMemorySpaceForStorageClass and getStorageClassForMemorySpace, so every
/// storage class and every memory space number must appear at most once
/// (duplicates would produce duplicate case labels).
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
  MAP_FN(spirv::StorageClass::Image, 13)                                       \
  MAP_FN(spirv::StorageClass::CallableDataNV, 14)                              \
  MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15)                      \
  MAP_FN(spirv::StorageClass::RayPayloadNV, 16)                                \
  MAP_FN(spirv::StorageClass::HitAttributeNV, 17)                              \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18)                        \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19)                        \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)
130 
131 unsigned
132 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
133 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
134   case storage:                                                                \
135     return space;
136 
137   switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
138 #undef STORAGE_SPACE_MAP_FN
139   llvm_unreachable("unhandled storage class!");
140 }
141 
142 Optional<spirv::StorageClass>
143 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
144 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
145   case space:                                                                  \
146     return storage;
147 
148   switch (space) {
149     STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
150   default:
151     return llvm::None;
152   }
153 #undef STORAGE_SPACE_MAP_FN
154 }
155 
156 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
157   return options;
158 }
159 
160 #undef STORAGE_SPACE_MAP_LIST
161 
162 // TODO: This is a utility function that should probably be exposed by the
163 // SPIR-V dialect. Keeping it local till the use case arises.
164 static Optional<int64_t>
165 getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
166   if (type.isa<spirv::ScalarType>()) {
167     auto bitWidth = type.getIntOrFloatBitWidth();
168     // According to the SPIR-V spec:
169     // "There is no physical size or bit pattern defined for values with boolean
170     // type. If they are stored (in conjunction with OpVariable), they can only
171     // be used with logical addressing operations, not physical, and only with
172     // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
173     // Private, Function, Input, and Output."
174     if (bitWidth == 1)
175       return llvm::None;
176     return bitWidth / 8;
177   }
178 
179   if (auto vecType = type.dyn_cast<VectorType>()) {
180     auto elementSize = getTypeNumBytes(options, vecType.getElementType());
181     if (!elementSize)
182       return llvm::None;
183     return vecType.getNumElements() * elementSize.getValue();
184   }
185 
186   if (auto memRefType = type.dyn_cast<MemRefType>()) {
187     // TODO: Layout should also be controlled by the ABI attributes. For now
188     // using the layout from MemRef.
189     int64_t offset;
190     SmallVector<int64_t, 4> strides;
191     if (!memRefType.hasStaticShape() ||
192         failed(getStridesAndOffset(memRefType, strides, offset)))
193       return llvm::None;
194 
195     // To get the size of the memref object in memory, the total size is the
196     // max(stride * dimension-size) computed for all dimensions times the size
197     // of the element.
198     auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
199     if (!elementSize)
200       return llvm::None;
201 
202     if (memRefType.getRank() == 0)
203       return elementSize;
204 
205     auto dims = memRefType.getShape();
206     if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
207         offset == MemRefType::getDynamicStrideOrOffset() ||
208         llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
209       return llvm::None;
210 
211     int64_t memrefSize = -1;
212     for (auto shape : enumerate(dims))
213       memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
214 
215     return (offset + memrefSize) * elementSize.getValue();
216   }
217 
218   if (auto tensorType = type.dyn_cast<TensorType>()) {
219     if (!tensorType.hasStaticShape())
220       return llvm::None;
221 
222     auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
223     if (!elementSize)
224       return llvm::None;
225 
226     int64_t size = elementSize.getValue();
227     for (auto shape : tensorType.getShape())
228       size *= shape;
229 
230     return size;
231   }
232 
233   // TODO: Add size computation for other types.
234   return llvm::None;
235 }
236 
237 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
238 static Type convertScalarType(const spirv::TargetEnv &targetEnv,
239                               const SPIRVTypeConverter::Options &options,
240                               spirv::ScalarType type,
241                               Optional<spirv::StorageClass> storageClass = {}) {
242   // Get extension and capability requirements for the given type.
243   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
244   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
245   type.getExtensions(extensions, storageClass);
246   type.getCapabilities(capabilities, storageClass);
247 
248   // If all requirements are met, then we can accept this type as-is.
249   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
250       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
251     return type;
252 
253   // Otherwise we need to adjust the type, which really means adjusting the
254   // bitwidth given this is a scalar type.
255 
256   if (!options.emulateNon32BitScalarTypes)
257     return nullptr;
258 
259   if (auto floatType = type.dyn_cast<FloatType>()) {
260     LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
261     return Builder(targetEnv.getContext()).getF32Type();
262   }
263 
264   auto intType = type.cast<IntegerType>();
265   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
266   return IntegerType::get(targetEnv.getContext(), /*width=*/32,
267                           intType.getSignedness());
268 }
269 
270 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
271 static Type convertVectorType(const spirv::TargetEnv &targetEnv,
272                               const SPIRVTypeConverter::Options &options,
273                               VectorType type,
274                               Optional<spirv::StorageClass> storageClass = {}) {
275   if (type.getRank() == 1 && type.getNumElements() == 1)
276     return type.getElementType();
277 
278   if (!spirv::CompositeType::isValid(type)) {
279     // TODO: Vector types with more than four elements can be translated into
280     // array types.
281     LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
282     return nullptr;
283   }
284 
285   // Get extension and capability requirements for the given type.
286   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
287   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
288   type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
289   type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
290 
291   // If all requirements are met, then we can accept this type as-is.
292   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
293       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
294     return type;
295 
296   auto elementType = convertScalarType(
297       targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
298       storageClass);
299   if (elementType)
300     return VectorType::get(type.getShape(), elementType);
301   return nullptr;
302 }
303 
304 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
305 ///
306 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
307 /// create composite constants with OpConstantComposite to embed relative large
308 /// constant values and use OpCompositeExtract and OpCompositeInsert to
309 /// manipulate, like what we do for vectors.
310 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
311                               const SPIRVTypeConverter::Options &options,
312                               TensorType type) {
313   // TODO: Handle dynamic shapes.
314   if (!type.hasStaticShape()) {
315     LLVM_DEBUG(llvm::dbgs()
316                << type << " illegal: dynamic shape unimplemented\n");
317     return nullptr;
318   }
319 
320   auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
321   if (!scalarType) {
322     LLVM_DEBUG(llvm::dbgs()
323                << type << " illegal: cannot convert non-scalar element type\n");
324     return nullptr;
325   }
326 
327   Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
328   Optional<int64_t> tensorSize = getTypeNumBytes(options, type);
329   if (!scalarSize || !tensorSize) {
330     LLVM_DEBUG(llvm::dbgs()
331                << type << " illegal: cannot deduce element count\n");
332     return nullptr;
333   }
334 
335   auto arrayElemCount = *tensorSize / *scalarSize;
336   auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
337   if (!arrayElemType)
338     return nullptr;
339   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
340   if (!arrayElemSize) {
341     LLVM_DEBUG(llvm::dbgs()
342                << type << " illegal: cannot deduce converted element size\n");
343     return nullptr;
344   }
345 
346   return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
347 }
348 
349 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
350                                   const SPIRVTypeConverter::Options &options,
351                                   MemRefType type) {
352   if (!type.hasStaticShape()) {
353     LLVM_DEBUG(llvm::dbgs()
354                << type << " dynamic shape on i1 is not supported yet\n");
355     return nullptr;
356   }
357 
358   Optional<spirv::StorageClass> storageClass =
359       SPIRVTypeConverter::getStorageClassForMemorySpace(
360           type.getMemorySpaceAsInt());
361   if (!storageClass) {
362     LLVM_DEBUG(llvm::dbgs()
363                << type << " illegal: cannot convert memory space\n");
364     return nullptr;
365   }
366 
367   unsigned numBoolBits = options.boolNumBits;
368   if (numBoolBits != 8) {
369     LLVM_DEBUG(llvm::dbgs()
370                << "using non-8-bit storage for bool types unimplemented");
371     return nullptr;
372   }
373   auto elementType = IntegerType::get(type.getContext(), numBoolBits)
374                          .dyn_cast<spirv::ScalarType>();
375   if (!elementType)
376     return nullptr;
377   Type arrayElemType =
378       convertScalarType(targetEnv, options, elementType, storageClass);
379   if (!arrayElemType)
380     return nullptr;
381   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
382   if (!arrayElemSize) {
383     LLVM_DEBUG(llvm::dbgs()
384                << type << " illegal: cannot deduce converted element size\n");
385     return nullptr;
386   }
387 
388   int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
389   auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
390   auto arrayType =
391       spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
392 
393   // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
394   // workgroup storage class do not need the struct to be laid out explicitly.
395   auto structType = *storageClass == spirv::StorageClass::Workgroup
396                         ? spirv::StructType::get(arrayType)
397                         : spirv::StructType::get(arrayType, 0);
398   return spirv::PointerType::get(structType, *storageClass);
399 }
400 
401 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
402                               const SPIRVTypeConverter::Options &options,
403                               MemRefType type) {
404   if (type.getElementType().isa<IntegerType>() &&
405       type.getElementTypeBitWidth() == 1) {
406     return convertBoolMemrefType(targetEnv, options, type);
407   }
408 
409   Optional<spirv::StorageClass> storageClass =
410       SPIRVTypeConverter::getStorageClassForMemorySpace(
411           type.getMemorySpaceAsInt());
412   if (!storageClass) {
413     LLVM_DEBUG(llvm::dbgs()
414                << type << " illegal: cannot convert memory space\n");
415     return nullptr;
416   }
417 
418   Type arrayElemType;
419   Type elementType = type.getElementType();
420   if (auto vecType = elementType.dyn_cast<VectorType>()) {
421     arrayElemType =
422         convertVectorType(targetEnv, options, vecType, storageClass);
423   } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
424     arrayElemType =
425         convertScalarType(targetEnv, options, scalarType, storageClass);
426   } else {
427     LLVM_DEBUG(
428         llvm::dbgs()
429         << type
430         << " unhandled: can only convert scalar or vector element type\n");
431     return nullptr;
432   }
433   if (!arrayElemType)
434     return nullptr;
435 
436   Optional<int64_t> elementSize = getTypeNumBytes(options, elementType);
437   if (!elementSize) {
438     LLVM_DEBUG(llvm::dbgs()
439                << type << " illegal: cannot deduce element size\n");
440     return nullptr;
441   }
442 
443   if (!type.hasStaticShape()) {
444     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, *elementSize);
445     // Wrap in a struct to satisfy Vulkan interface requirements.
446     auto structType = spirv::StructType::get(arrayType, 0);
447     return spirv::PointerType::get(structType, *storageClass);
448   }
449 
450   Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
451   if (!memrefSize) {
452     LLVM_DEBUG(llvm::dbgs()
453                << type << " illegal: cannot deduce element count\n");
454     return nullptr;
455   }
456 
457   auto arrayElemCount = *memrefSize / *elementSize;
458 
459   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
460   if (!arrayElemSize) {
461     LLVM_DEBUG(llvm::dbgs()
462                << type << " illegal: cannot deduce converted element size\n");
463     return nullptr;
464   }
465 
466   auto arrayType =
467       spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
468 
469   // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
470   // workgroup storage class do not need the struct to be laid out explicitly.
471   auto structType = *storageClass == spirv::StorageClass::Workgroup
472                         ? spirv::StructType::get(arrayType)
473                         : spirv::StructType::get(arrayType, 0);
474   return spirv::PointerType::get(structType, *storageClass);
475 }
476 
477 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
478                                        Options options)
479     : targetEnv(targetAttr), options(options) {
480   // Add conversions. The order matters here: later ones will be tried earlier.
481 
482   // Allow all SPIR-V dialect specific types. This assumes all builtin types
483   // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
484   // were tried before.
485   //
486   // TODO: this assumes that the SPIR-V types are valid to use in
487   // the given target environment, which should be the case if the whole
488   // pipeline is driven by the same target environment. Still, we probably still
489   // want to validate and convert to be safe.
490   addConversion([](spirv::SPIRVType type) { return type; });
491 
492   addConversion([](IndexType indexType) {
493     return SPIRVTypeConverter::getIndexType(indexType.getContext());
494   });
495 
496   addConversion([this](IntegerType intType) -> Optional<Type> {
497     if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
498       return convertScalarType(this->targetEnv, this->options, scalarType);
499     return Type();
500   });
501 
502   addConversion([this](FloatType floatType) -> Optional<Type> {
503     if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
504       return convertScalarType(this->targetEnv, this->options, scalarType);
505     return Type();
506   });
507 
508   addConversion([this](VectorType vectorType) {
509     return convertVectorType(this->targetEnv, this->options, vectorType);
510   });
511 
512   addConversion([this](TensorType tensorType) {
513     return convertTensorType(this->targetEnv, this->options, tensorType);
514   });
515 
516   addConversion([this](MemRefType memRefType) {
517     return convertMemrefType(this->targetEnv, this->options, memRefType);
518   });
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // FuncOp Conversion Patterns
523 //===----------------------------------------------------------------------===//
524 
525 namespace {
526 /// A pattern for rewriting function signature to convert arguments of functions
527 /// to be of valid SPIR-V types.
528 class FuncOpConversion final : public OpConversionPattern<FuncOp> {
529 public:
530   using OpConversionPattern<FuncOp>::OpConversionPattern;
531 
532   LogicalResult
533   matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
534                   ConversionPatternRewriter &rewriter) const override;
535 };
536 } // namespace
537 
538 LogicalResult
539 FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
540                                   ConversionPatternRewriter &rewriter) const {
541   auto fnType = funcOp.getType();
542   if (fnType.getNumResults() > 1)
543     return failure();
544 
545   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
546   for (auto argType : enumerate(fnType.getInputs())) {
547     auto convertedType = getTypeConverter()->convertType(argType.value());
548     if (!convertedType)
549       return failure();
550     signatureConverter.addInputs(argType.index(), convertedType);
551   }
552 
553   Type resultType;
554   if (fnType.getNumResults() == 1) {
555     resultType = getTypeConverter()->convertType(fnType.getResult(0));
556     if (!resultType)
557       return failure();
558   }
559 
560   // Create the converted spv.func op.
561   auto newFuncOp = rewriter.create<spirv::FuncOp>(
562       funcOp.getLoc(), funcOp.getName(),
563       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
564                                resultType ? TypeRange(resultType)
565                                           : TypeRange()));
566 
567   // Copy over all attributes other than the function name and type.
568   for (const auto &namedAttr : funcOp->getAttrs()) {
569     if (namedAttr.first != impl::getTypeAttrName() &&
570         namedAttr.first != SymbolTable::getSymbolAttrName())
571       newFuncOp->setAttr(namedAttr.first, namedAttr.second);
572   }
573 
574   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
575                               newFuncOp.end());
576   if (failed(rewriter.convertRegionTypes(
577           &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
578     return failure();
579   rewriter.eraseOp(funcOp);
580   return success();
581 }
582 
583 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
584                                               RewritePatternSet &patterns) {
585   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
586 }
587 
588 //===----------------------------------------------------------------------===//
589 // Builtin Variables
590 //===----------------------------------------------------------------------===//
591 
592 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
593                                                   spirv::BuiltIn builtin) {
594   // Look through all global variables in the given `body` block and check if
595   // there is a spv.GlobalVariable that has the same `builtin` attribute.
596   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
597     if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
598             spirv::SPIRVDialect::getAttributeName(
599                 spirv::Decoration::BuiltIn))) {
600       auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
601       if (varBuiltIn && varBuiltIn.getValue() == builtin) {
602         return varOp;
603       }
604     }
605   }
606   return nullptr;
607 }
608 
609 /// Gets name of global variable for a builtin.
610 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
611   return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
612 }
613 
614 /// Gets or inserts a global variable for a builtin within `body` block.
615 static spirv::GlobalVariableOp
616 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
617                            OpBuilder &builder) {
618   if (auto varOp = getBuiltinVariable(body, builtin))
619     return varOp;
620 
621   OpBuilder::InsertionGuard guard(builder);
622   builder.setInsertionPointToStart(&body);
623 
624   spirv::GlobalVariableOp newVarOp;
625   switch (builtin) {
626   case spirv::BuiltIn::NumWorkgroups:
627   case spirv::BuiltIn::WorkgroupSize:
628   case spirv::BuiltIn::WorkgroupId:
629   case spirv::BuiltIn::LocalInvocationId:
630   case spirv::BuiltIn::GlobalInvocationId: {
631     auto ptrType = spirv::PointerType::get(
632         VectorType::get({3}, builder.getIntegerType(32)),
633         spirv::StorageClass::Input);
634     std::string name = getBuiltinVarName(builtin);
635     newVarOp =
636         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
637     break;
638   }
639   case spirv::BuiltIn::SubgroupId:
640   case spirv::BuiltIn::NumSubgroups:
641   case spirv::BuiltIn::SubgroupSize: {
642     auto ptrType = spirv::PointerType::get(builder.getIntegerType(32),
643                                            spirv::StorageClass::Input);
644     std::string name = getBuiltinVarName(builtin);
645     newVarOp =
646         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
647     break;
648   }
649   default:
650     emitError(loc, "unimplemented builtin variable generation for ")
651         << stringifyBuiltIn(builtin);
652   }
653   return newVarOp;
654 }
655 
656 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
657                                            spirv::BuiltIn builtin,
658                                            OpBuilder &builder) {
659   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
660   if (!parent) {
661     op->emitError("expected operation to be within a module-like op");
662     return nullptr;
663   }
664 
665   spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
666       *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
667   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
668   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // Push constant storage
673 //===----------------------------------------------------------------------===//
674 
675 /// Returns the pointer type for the push constant storage containing
676 /// `elementCount` 32-bit integer values.
677 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
678                                                      Builder &builder) {
679   auto arrayType = spirv::ArrayType::get(
680       SPIRVTypeConverter::getIndexType(builder.getContext()), elementCount,
681       /*stride=*/4);
682   auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
683   return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
684 }
685 
686 /// Returns the push constant varible containing `elementCount` 32-bit integer
687 /// values in `body`. Returns null op if such an op does not exit.
688 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
689                                                        unsigned elementCount) {
690   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
691     auto ptrType = varOp.type().cast<spirv::PointerType>();
692     // Note that Vulkan requires "There must be no more than one push constant
693     // block statically used per shader entry point." So we should always reuse
694     // the existing one.
695     if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
696       auto numElements = ptrType.getPointeeType()
697                              .cast<spirv::StructType>()
698                              .getElementType(0)
699                              .cast<spirv::ArrayType>()
700                              .getNumElements();
701       if (numElements == elementCount)
702         return varOp;
703     }
704   }
705   return nullptr;
706 }
707 
708 /// Gets or inserts a global variable for push constant storage containing
709 /// `elementCount` 32-bit integer values in `block`.
710 static spirv::GlobalVariableOp
711 getOrInsertPushConstantVariable(Location loc, Block &block,
712                                 unsigned elementCount, OpBuilder &b) {
713   if (auto varOp = getPushConstantVariable(block, elementCount))
714     return varOp;
715 
716   auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
717   auto type = getPushConstantStorageType(elementCount, builder);
718   const char *name = "__push_constant_var__";
719   return builder.create<spirv::GlobalVariableOp>(loc, type, name,
720                                                  /*initializer=*/nullptr);
721 }
722 
723 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
724                                   unsigned offset, OpBuilder &builder) {
725   Location loc = op->getLoc();
726   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
727   if (!parent) {
728     op->emitError("expected operation to be within a module-like op");
729     return nullptr;
730   }
731 
732   spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
733       loc, parent->getRegion(0).front(), elementCount, builder);
734 
735   auto i32Type = SPIRVTypeConverter::getIndexType(builder.getContext());
736   Value zeroOp = spirv::ConstantOp::getZero(i32Type, loc, builder);
737   Value offsetOp = builder.create<spirv::ConstantOp>(
738       loc, i32Type, builder.getI32IntegerAttr(offset));
739   auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
740   auto acOp = builder.create<spirv::AccessChainOp>(
741       loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
742   return builder.create<spirv::LoadOp>(loc, acOp);
743 }
744 
745 //===----------------------------------------------------------------------===//
746 // Index calculation
747 //===----------------------------------------------------------------------===//
748 
749 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
750                                   int64_t offset, Location loc,
751                                   OpBuilder &builder) {
752   assert(indices.size() == strides.size() &&
753          "must provide indices for all dimensions");
754 
755   auto indexType = SPIRVTypeConverter::getIndexType(builder.getContext());
756 
757   // TODO: Consider moving to use affine.apply and patterns converting
758   // affine.apply to standard ops. This needs converting to SPIR-V passes to be
759   // broken down into progressive small steps so we can have intermediate steps
760   // using other dialects. At the moment SPIR-V is the final sink.
761 
762   Value linearizedIndex = builder.create<spirv::ConstantOp>(
763       loc, indexType, IntegerAttr::get(indexType, offset));
764   for (auto index : llvm::enumerate(indices)) {
765     Value strideVal = builder.create<spirv::ConstantOp>(
766         loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
767     Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
768     linearizedIndex =
769         builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
770   }
771   return linearizedIndex;
772 }
773 
774 spirv::AccessChainOp mlir::spirv::getElementPtr(
775     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
776     ValueRange indices, Location loc, OpBuilder &builder) {
777   // Get base and offset of the MemRefType and verify they are static.
778 
779   int64_t offset;
780   SmallVector<int64_t, 4> strides;
781   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
782       llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
783       offset == MemRefType::getDynamicStrideOrOffset()) {
784     return nullptr;
785   }
786 
787   auto indexType = typeConverter.getIndexType(builder.getContext());
788 
789   SmallVector<Value, 2> linearizedIndices;
790   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
791 
792   // Add a '0' at the start to index into the struct.
793   linearizedIndices.push_back(zero);
794 
795   if (baseType.getRank() == 0) {
796     linearizedIndices.push_back(zero);
797   } else {
798     linearizedIndices.push_back(
799         linearizeIndex(indices, strides, offset, loc, builder));
800   }
801   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
802 }
803 
804 //===----------------------------------------------------------------------===//
805 // SPIR-V ConversionTarget
806 //===----------------------------------------------------------------------===//
807 
808 std::unique_ptr<SPIRVConversionTarget>
809 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
810   std::unique_ptr<SPIRVConversionTarget> target(
811       // std::make_unique does not work here because the constructor is private.
812       new SPIRVConversionTarget(targetAttr));
813   SPIRVConversionTarget *targetPtr = target.get();
814   target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
815       // We need to capture the raw pointer here because it is stable:
816       // target will be destroyed once this function is returned.
817       [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
818   return target;
819 }
820 
821 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
822     : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
823 
824 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
825   // Make sure this op is available at the given version. Ops not implementing
826   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
827   // SPIR-V versions.
828   if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
829     if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
830       LLVM_DEBUG(llvm::dbgs()
831                  << op->getName() << " illegal: requiring min version "
832                  << spirv::stringifyVersion(minVersion.getMinVersion())
833                  << "\n");
834       return false;
835     }
836   if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
837     if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
838       LLVM_DEBUG(llvm::dbgs()
839                  << op->getName() << " illegal: requiring max version "
840                  << spirv::stringifyVersion(maxVersion.getMaxVersion())
841                  << "\n");
842       return false;
843     }
844 
845   // Make sure this op's required extensions are allowed to use. Ops not
846   // implementing QueryExtensionInterface do not require extensions to be
847   // available.
848   if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
849     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
850                                           extensions.getExtensions())))
851       return false;
852 
853   // Make sure this op's required extensions are allowed to use. Ops not
854   // implementing QueryCapabilityInterface do not require capabilities to be
855   // available.
856   if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
857     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
858                                            capabilities.getCapabilities())))
859       return false;
860 
861   SmallVector<Type, 4> valueTypes;
862   valueTypes.append(op->operand_type_begin(), op->operand_type_end());
863   valueTypes.append(op->result_type_begin(), op->result_type_end());
864 
865   // Special treatment for global variables, whose type requirements are
866   // conveyed by type attributes.
867   if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
868     valueTypes.push_back(globalVar.type());
869 
870   // Make sure the op's operands/results use types that are allowed by the
871   // target environment.
872   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
873   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
874   for (Type valueType : valueTypes) {
875     typeExtensions.clear();
876     valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
877     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
878                                           typeExtensions)))
879       return false;
880 
881     typeCapabilities.clear();
882     valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
883     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
884                                            typeCapabilities)))
885       return false;
886   }
887 
888   return true;
889 }
890