1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <functional>
22 
23 #define DEBUG_TYPE "mlir-spirv-conversion"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Checks that `candidates` extension requirements are possible to be satisfied
32 /// with the given `targetEnv`.
33 ///
34 ///  `candidates` is a vector of vector for extension requirements following
35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
36 /// convention.
37 template <typename LabelT>
38 static LogicalResult checkExtensionRequirements(
39     LabelT label, const spirv::TargetEnv &targetEnv,
40     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
41   for (const auto &ors : candidates) {
42     if (targetEnv.allows(ors))
43       continue;
44 
45     LLVM_DEBUG({
46       SmallVector<StringRef> extStrings;
47       for (spirv::Extension ext : ors)
48         extStrings.push_back(spirv::stringifyExtension(ext));
49 
50       llvm::dbgs() << label << " illegal: requires at least one extension in ["
51                    << llvm::join(extStrings, ", ")
52                    << "] but none allowed in target environment\n";
53     });
54     return failure();
55   }
56   return success();
57 }
58 
59 /// Checks that `candidates`capability requirements are possible to be satisfied
60 /// with the given `isAllowedFn`.
61 ///
62 ///  `candidates` is a vector of vector for capability requirements following
63 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
64 /// convention.
65 template <typename LabelT>
66 static LogicalResult checkCapabilityRequirements(
67     LabelT label, const spirv::TargetEnv &targetEnv,
68     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
69   for (const auto &ors : candidates) {
70     if (targetEnv.allows(ors))
71       continue;
72 
73     LLVM_DEBUG({
74       SmallVector<StringRef> capStrings;
75       for (spirv::Capability cap : ors)
76         capStrings.push_back(spirv::stringifyCapability(cap));
77 
78       llvm::dbgs() << label << " illegal: requires at least one capability in ["
79                    << llvm::join(capStrings, ", ")
80                    << "] but none allowed in target environment\n";
81     });
82     return failure();
83   }
84   return success();
85 }
86 
87 /// Returns true if the given `storageClass` needs explicit layout when used in
88 /// Shader environments.
89 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
90   switch (storageClass) {
91   case spirv::StorageClass::PhysicalStorageBuffer:
92   case spirv::StorageClass::PushConstant:
93   case spirv::StorageClass::StorageBuffer:
94   case spirv::StorageClass::Uniform:
95     return true;
96   default:
97     return false;
98   }
99 }
100 
101 /// Wraps the given `elementType` in a struct and gets the pointer to the
102 /// struct. This is used to satisfy Vulkan interface requirements.
103 static spirv::PointerType
104 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
105   auto structType = needsExplicitLayout(storageClass)
106                         ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
107                         : spirv::StructType::get(elementType);
108   return spirv::PointerType::get(structType, storageClass);
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // Type Conversion
113 //===----------------------------------------------------------------------===//
114 
115 Type SPIRVTypeConverter::getIndexType() const {
116   return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
117 }
118 
/// Mapping between SPIR-V storage classes and memref memory spaces.
///
/// Each MAP_FN(storageClass, memorySpace) entry pairs one storage class with
/// one numeric memory space. The list is expanded into `case` labels by
/// getMemorySpaceForStorageClass() and getStorageClassForMemorySpace().
///
/// Note: memref does not have defined semantics for each memory space; it
/// depends on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes a smaller number. The hope is to use a symbolic
/// memory space representation eventually, once memref supports it.
/// Memory space 2 is not assigned to any storage class.
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
  MAP_FN(spirv::StorageClass::Image, 13)                                       \
  MAP_FN(spirv::StorageClass::CallableDataNV, 14)                              \
  MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15)                      \
  MAP_FN(spirv::StorageClass::RayPayloadNV, 16)                                \
  MAP_FN(spirv::StorageClass::HitAttributeNV, 17)                              \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18)                        \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19)                        \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)
149 
150 unsigned
151 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
152 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
153   case storage:                                                                \
154     return space;
155 
156   switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
157 #undef STORAGE_SPACE_MAP_FN
158   llvm_unreachable("unhandled storage class!");
159 }
160 
161 Optional<spirv::StorageClass>
162 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
163 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
164   case space:                                                                  \
165     return storage;
166 
167   switch (space) {
168     STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
169   default:
170     return llvm::None;
171   }
172 #undef STORAGE_SPACE_MAP_FN
173 }
174 
175 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
176   return options;
177 }
178 
179 MLIRContext *SPIRVTypeConverter::getContext() const {
180   return targetEnv.getAttr().getContext();
181 }
182 
// STORAGE_SPACE_MAP_LIST is only expanded by the storage class <-> memory space
// conversion functions above, so retire it here.
#undef STORAGE_SPACE_MAP_LIST
184 
185 // TODO: This is a utility function that should probably be exposed by the
186 // SPIR-V dialect. Keeping it local till the use case arises.
187 static Optional<int64_t>
188 getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
189   if (type.isa<spirv::ScalarType>()) {
190     auto bitWidth = type.getIntOrFloatBitWidth();
191     // According to the SPIR-V spec:
192     // "There is no physical size or bit pattern defined for values with boolean
193     // type. If they are stored (in conjunction with OpVariable), they can only
194     // be used with logical addressing operations, not physical, and only with
195     // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
196     // Private, Function, Input, and Output."
197     if (bitWidth == 1)
198       return llvm::None;
199     return bitWidth / 8;
200   }
201 
202   if (auto vecType = type.dyn_cast<VectorType>()) {
203     auto elementSize = getTypeNumBytes(options, vecType.getElementType());
204     if (!elementSize)
205       return llvm::None;
206     return vecType.getNumElements() * elementSize.getValue();
207   }
208 
209   if (auto memRefType = type.dyn_cast<MemRefType>()) {
210     // TODO: Layout should also be controlled by the ABI attributes. For now
211     // using the layout from MemRef.
212     int64_t offset;
213     SmallVector<int64_t, 4> strides;
214     if (!memRefType.hasStaticShape() ||
215         failed(getStridesAndOffset(memRefType, strides, offset)))
216       return llvm::None;
217 
218     // To get the size of the memref object in memory, the total size is the
219     // max(stride * dimension-size) computed for all dimensions times the size
220     // of the element.
221     auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
222     if (!elementSize)
223       return llvm::None;
224 
225     if (memRefType.getRank() == 0)
226       return elementSize;
227 
228     auto dims = memRefType.getShape();
229     if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
230         offset == MemRefType::getDynamicStrideOrOffset() ||
231         llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
232       return llvm::None;
233 
234     int64_t memrefSize = -1;
235     for (auto shape : enumerate(dims))
236       memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
237 
238     return (offset + memrefSize) * elementSize.getValue();
239   }
240 
241   if (auto tensorType = type.dyn_cast<TensorType>()) {
242     if (!tensorType.hasStaticShape())
243       return llvm::None;
244 
245     auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
246     if (!elementSize)
247       return llvm::None;
248 
249     int64_t size = elementSize.getValue();
250     for (auto shape : tensorType.getShape())
251       size *= shape;
252 
253     return size;
254   }
255 
256   // TODO: Add size computation for other types.
257   return llvm::None;
258 }
259 
260 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
261 static Type convertScalarType(const spirv::TargetEnv &targetEnv,
262                               const SPIRVTypeConverter::Options &options,
263                               spirv::ScalarType type,
264                               Optional<spirv::StorageClass> storageClass = {}) {
265   // Get extension and capability requirements for the given type.
266   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
267   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
268   type.getExtensions(extensions, storageClass);
269   type.getCapabilities(capabilities, storageClass);
270 
271   // If all requirements are met, then we can accept this type as-is.
272   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
273       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
274     return type;
275 
276   // Otherwise we need to adjust the type, which really means adjusting the
277   // bitwidth given this is a scalar type.
278 
279   if (!options.emulateNon32BitScalarTypes)
280     return nullptr;
281 
282   if (auto floatType = type.dyn_cast<FloatType>()) {
283     LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
284     return Builder(targetEnv.getContext()).getF32Type();
285   }
286 
287   auto intType = type.cast<IntegerType>();
288   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
289   return IntegerType::get(targetEnv.getContext(), /*width=*/32,
290                           intType.getSignedness());
291 }
292 
293 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
294 static Type convertVectorType(const spirv::TargetEnv &targetEnv,
295                               const SPIRVTypeConverter::Options &options,
296                               VectorType type,
297                               Optional<spirv::StorageClass> storageClass = {}) {
298   if (type.getRank() == 1 && type.getNumElements() == 1)
299     return type.getElementType();
300 
301   if (!spirv::CompositeType::isValid(type)) {
302     // TODO: Vector types with more than four elements can be translated into
303     // array types.
304     LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
305     return nullptr;
306   }
307 
308   // Get extension and capability requirements for the given type.
309   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
310   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
311   type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
312   type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
313 
314   // If all requirements are met, then we can accept this type as-is.
315   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
316       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
317     return type;
318 
319   auto elementType = convertScalarType(
320       targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
321       storageClass);
322   if (elementType)
323     return VectorType::get(type.getShape(), elementType);
324   return nullptr;
325 }
326 
327 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
328 ///
329 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
330 /// create composite constants with OpConstantComposite to embed relative large
331 /// constant values and use OpCompositeExtract and OpCompositeInsert to
332 /// manipulate, like what we do for vectors.
333 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
334                               const SPIRVTypeConverter::Options &options,
335                               TensorType type) {
336   // TODO: Handle dynamic shapes.
337   if (!type.hasStaticShape()) {
338     LLVM_DEBUG(llvm::dbgs()
339                << type << " illegal: dynamic shape unimplemented\n");
340     return nullptr;
341   }
342 
343   auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
344   if (!scalarType) {
345     LLVM_DEBUG(llvm::dbgs()
346                << type << " illegal: cannot convert non-scalar element type\n");
347     return nullptr;
348   }
349 
350   Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
351   Optional<int64_t> tensorSize = getTypeNumBytes(options, type);
352   if (!scalarSize || !tensorSize) {
353     LLVM_DEBUG(llvm::dbgs()
354                << type << " illegal: cannot deduce element count\n");
355     return nullptr;
356   }
357 
358   auto arrayElemCount = *tensorSize / *scalarSize;
359   auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
360   if (!arrayElemType)
361     return nullptr;
362   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
363   if (!arrayElemSize) {
364     LLVM_DEBUG(llvm::dbgs()
365                << type << " illegal: cannot deduce converted element size\n");
366     return nullptr;
367   }
368 
369   return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
370 }
371 
372 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
373                                   const SPIRVTypeConverter::Options &options,
374                                   MemRefType type) {
375   Optional<spirv::StorageClass> storageClass =
376       SPIRVTypeConverter::getStorageClassForMemorySpace(
377           type.getMemorySpaceAsInt());
378   if (!storageClass) {
379     LLVM_DEBUG(llvm::dbgs()
380                << type << " illegal: cannot convert memory space\n");
381     return nullptr;
382   }
383 
384   unsigned numBoolBits = options.boolNumBits;
385   if (numBoolBits != 8) {
386     LLVM_DEBUG(llvm::dbgs()
387                << "using non-8-bit storage for bool types unimplemented");
388     return nullptr;
389   }
390   auto elementType = IntegerType::get(type.getContext(), numBoolBits)
391                          .dyn_cast<spirv::ScalarType>();
392   if (!elementType)
393     return nullptr;
394   Type arrayElemType =
395       convertScalarType(targetEnv, options, elementType, storageClass);
396   if (!arrayElemType)
397     return nullptr;
398   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
399   if (!arrayElemSize) {
400     LLVM_DEBUG(llvm::dbgs()
401                << type << " illegal: cannot deduce converted element size\n");
402     return nullptr;
403   }
404 
405   if (!type.hasStaticShape()) {
406     auto arrayType =
407         spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
408     return wrapInStructAndGetPointer(arrayType, *storageClass);
409   }
410 
411   int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
412   auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
413   auto arrayType =
414       spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
415 
416   return wrapInStructAndGetPointer(arrayType, *storageClass);
417 }
418 
419 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
420                               const SPIRVTypeConverter::Options &options,
421                               MemRefType type) {
422   if (type.getElementType().isa<IntegerType>() &&
423       type.getElementTypeBitWidth() == 1) {
424     return convertBoolMemrefType(targetEnv, options, type);
425   }
426 
427   Optional<spirv::StorageClass> storageClass =
428       SPIRVTypeConverter::getStorageClassForMemorySpace(
429           type.getMemorySpaceAsInt());
430   if (!storageClass) {
431     LLVM_DEBUG(llvm::dbgs()
432                << type << " illegal: cannot convert memory space\n");
433     return nullptr;
434   }
435 
436   Type arrayElemType;
437   Type elementType = type.getElementType();
438   if (auto vecType = elementType.dyn_cast<VectorType>()) {
439     arrayElemType =
440         convertVectorType(targetEnv, options, vecType, storageClass);
441   } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
442     arrayElemType =
443         convertScalarType(targetEnv, options, scalarType, storageClass);
444   } else {
445     LLVM_DEBUG(
446         llvm::dbgs()
447         << type
448         << " unhandled: can only convert scalar or vector element type\n");
449     return nullptr;
450   }
451   if (!arrayElemType)
452     return nullptr;
453 
454   Optional<int64_t> elementSize = getTypeNumBytes(options, elementType);
455   if (!elementSize) {
456     LLVM_DEBUG(llvm::dbgs()
457                << type << " illegal: cannot deduce element size\n");
458     return nullptr;
459   }
460 
461   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
462   if (!arrayElemSize) {
463     LLVM_DEBUG(llvm::dbgs()
464                << type << " illegal: cannot deduce converted element size\n");
465     return nullptr;
466   }
467 
468   if (!type.hasStaticShape()) {
469     auto arrayType =
470         spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
471     return wrapInStructAndGetPointer(arrayType, *storageClass);
472   }
473 
474   Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
475   if (!memrefSize) {
476     LLVM_DEBUG(llvm::dbgs()
477                << type << " illegal: cannot deduce element count\n");
478     return nullptr;
479   }
480 
481   auto arrayElemCount = *memrefSize / *elementSize;
482 
483 
484   auto arrayType =
485       spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
486 
487   return wrapInStructAndGetPointer(arrayType, *storageClass);
488 }
489 
490 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
491                                        Options options)
492     : targetEnv(targetAttr), options(options) {
493   // Add conversions. The order matters here: later ones will be tried earlier.
494 
495   // Allow all SPIR-V dialect specific types. This assumes all builtin types
496   // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
497   // were tried before.
498   //
499   // TODO: this assumes that the SPIR-V types are valid to use in
500   // the given target environment, which should be the case if the whole
501   // pipeline is driven by the same target environment. Still, we probably still
502   // want to validate and convert to be safe.
503   addConversion([](spirv::SPIRVType type) { return type; });
504 
505   addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
506 
507   addConversion([this](IntegerType intType) -> Optional<Type> {
508     if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
509       return convertScalarType(this->targetEnv, this->options, scalarType);
510     return Type();
511   });
512 
513   addConversion([this](FloatType floatType) -> Optional<Type> {
514     if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
515       return convertScalarType(this->targetEnv, this->options, scalarType);
516     return Type();
517   });
518 
519   addConversion([this](VectorType vectorType) {
520     return convertVectorType(this->targetEnv, this->options, vectorType);
521   });
522 
523   addConversion([this](TensorType tensorType) {
524     return convertTensorType(this->targetEnv, this->options, tensorType);
525   });
526 
527   addConversion([this](MemRefType memRefType) {
528     return convertMemrefType(this->targetEnv, this->options, memRefType);
529   });
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // FuncOp Conversion Patterns
534 //===----------------------------------------------------------------------===//
535 
536 namespace {
537 /// A pattern for rewriting function signature to convert arguments of functions
538 /// to be of valid SPIR-V types.
539 class FuncOpConversion final : public OpConversionPattern<FuncOp> {
540 public:
541   using OpConversionPattern<FuncOp>::OpConversionPattern;
542 
543   LogicalResult
544   matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
545                   ConversionPatternRewriter &rewriter) const override;
546 };
547 } // namespace
548 
549 LogicalResult
550 FuncOpConversion::matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
551                                   ConversionPatternRewriter &rewriter) const {
552   auto fnType = funcOp.getType();
553   if (fnType.getNumResults() > 1)
554     return failure();
555 
556   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
557   for (auto argType : enumerate(fnType.getInputs())) {
558     auto convertedType = getTypeConverter()->convertType(argType.value());
559     if (!convertedType)
560       return failure();
561     signatureConverter.addInputs(argType.index(), convertedType);
562   }
563 
564   Type resultType;
565   if (fnType.getNumResults() == 1) {
566     resultType = getTypeConverter()->convertType(fnType.getResult(0));
567     if (!resultType)
568       return failure();
569   }
570 
571   // Create the converted spv.func op.
572   auto newFuncOp = rewriter.create<spirv::FuncOp>(
573       funcOp.getLoc(), funcOp.getName(),
574       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
575                                resultType ? TypeRange(resultType)
576                                           : TypeRange()));
577 
578   // Copy over all attributes other than the function name and type.
579   for (const auto &namedAttr : funcOp->getAttrs()) {
580     if (namedAttr.first != function_like_impl::getTypeAttrName() &&
581         namedAttr.first != SymbolTable::getSymbolAttrName())
582       newFuncOp->setAttr(namedAttr.first, namedAttr.second);
583   }
584 
585   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
586                               newFuncOp.end());
587   if (failed(rewriter.convertRegionTypes(
588           &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
589     return failure();
590   rewriter.eraseOp(funcOp);
591   return success();
592 }
593 
594 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
595                                               RewritePatternSet &patterns) {
596   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
597 }
598 
599 //===----------------------------------------------------------------------===//
600 // Builtin Variables
601 //===----------------------------------------------------------------------===//
602 
603 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
604                                                   spirv::BuiltIn builtin) {
605   // Look through all global variables in the given `body` block and check if
606   // there is a spv.GlobalVariable that has the same `builtin` attribute.
607   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
608     if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
609             spirv::SPIRVDialect::getAttributeName(
610                 spirv::Decoration::BuiltIn))) {
611       auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
612       if (varBuiltIn && varBuiltIn.getValue() == builtin) {
613         return varOp;
614       }
615     }
616   }
617   return nullptr;
618 }
619 
620 /// Gets name of global variable for a builtin.
621 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
622   return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
623 }
624 
625 /// Gets or inserts a global variable for a builtin within `body` block.
626 static spirv::GlobalVariableOp
627 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
628                            Type integerType, OpBuilder &builder) {
629   if (auto varOp = getBuiltinVariable(body, builtin))
630     return varOp;
631 
632   OpBuilder::InsertionGuard guard(builder);
633   builder.setInsertionPointToStart(&body);
634 
635   spirv::GlobalVariableOp newVarOp;
636   switch (builtin) {
637   case spirv::BuiltIn::NumWorkgroups:
638   case spirv::BuiltIn::WorkgroupSize:
639   case spirv::BuiltIn::WorkgroupId:
640   case spirv::BuiltIn::LocalInvocationId:
641   case spirv::BuiltIn::GlobalInvocationId: {
642     auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
643                                            spirv::StorageClass::Input);
644     std::string name = getBuiltinVarName(builtin);
645     newVarOp =
646         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
647     break;
648   }
649   case spirv::BuiltIn::SubgroupId:
650   case spirv::BuiltIn::NumSubgroups:
651   case spirv::BuiltIn::SubgroupSize: {
652     auto ptrType =
653         spirv::PointerType::get(integerType, spirv::StorageClass::Input);
654     std::string name = getBuiltinVarName(builtin);
655     newVarOp =
656         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
657     break;
658   }
659   default:
660     emitError(loc, "unimplemented builtin variable generation for ")
661         << stringifyBuiltIn(builtin);
662   }
663   return newVarOp;
664 }
665 
666 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
667                                            spirv::BuiltIn builtin,
668                                            Type integerType,
669                                            OpBuilder &builder) {
670   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
671   if (!parent) {
672     op->emitError("expected operation to be within a module-like op");
673     return nullptr;
674   }
675 
676   spirv::GlobalVariableOp varOp =
677       getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
678                                  builtin, integerType, builder);
679   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
680   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
681 }
682 
683 //===----------------------------------------------------------------------===//
684 // Push constant storage
685 //===----------------------------------------------------------------------===//
686 
687 /// Returns the pointer type for the push constant storage containing
688 /// `elementCount` 32-bit integer values.
689 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
690                                                      Builder &builder,
691                                                      Type indexType) {
692   auto arrayType = spirv::ArrayType::get(indexType, elementCount,
693                                          /*stride=*/4);
694   auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
695   return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
696 }
697 
698 /// Returns the push constant varible containing `elementCount` 32-bit integer
699 /// values in `body`. Returns null op if such an op does not exit.
700 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
701                                                        unsigned elementCount) {
702   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
703     auto ptrType = varOp.type().dyn_cast<spirv::PointerType>();
704     if (!ptrType)
705       continue;
706 
707     // Note that Vulkan requires "There must be no more than one push constant
708     // block statically used per shader entry point." So we should always reuse
709     // the existing one.
710     if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
711       auto numElements = ptrType.getPointeeType()
712                              .cast<spirv::StructType>()
713                              .getElementType(0)
714                              .cast<spirv::ArrayType>()
715                              .getNumElements();
716       if (numElements == elementCount)
717         return varOp;
718     }
719   }
720   return nullptr;
721 }
722 
723 /// Gets or inserts a global variable for push constant storage containing
724 /// `elementCount` 32-bit integer values in `block`.
725 static spirv::GlobalVariableOp
726 getOrInsertPushConstantVariable(Location loc, Block &block,
727                                 unsigned elementCount, OpBuilder &b,
728                                 Type indexType) {
729   if (auto varOp = getPushConstantVariable(block, elementCount))
730     return varOp;
731 
732   auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
733   auto type = getPushConstantStorageType(elementCount, builder, indexType);
734   const char *name = "__push_constant_var__";
735   return builder.create<spirv::GlobalVariableOp>(loc, type, name,
736                                                  /*initializer=*/nullptr);
737 }
738 
739 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
740                                   unsigned offset, Type integerType,
741                                   OpBuilder &builder) {
742   Location loc = op->getLoc();
743   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
744   if (!parent) {
745     op->emitError("expected operation to be within a module-like op");
746     return nullptr;
747   }
748 
749   spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
750       loc, parent->getRegion(0).front(), elementCount, builder, integerType);
751 
752   Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
753   Value offsetOp = builder.create<spirv::ConstantOp>(
754       loc, integerType, builder.getI32IntegerAttr(offset));
755   auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
756   auto acOp = builder.create<spirv::AccessChainOp>(
757       loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
758   return builder.create<spirv::LoadOp>(loc, acOp);
759 }
760 
761 //===----------------------------------------------------------------------===//
762 // Index calculation
763 //===----------------------------------------------------------------------===//
764 
765 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
766                                   int64_t offset, Type integerType,
767                                   Location loc, OpBuilder &builder) {
768   assert(indices.size() == strides.size() &&
769          "must provide indices for all dimensions");
770 
771   // TODO: Consider moving to use affine.apply and patterns converting
772   // affine.apply to standard ops. This needs converting to SPIR-V passes to be
773   // broken down into progressive small steps so we can have intermediate steps
774   // using other dialects. At the moment SPIR-V is the final sink.
775 
776   Value linearizedIndex = builder.create<spirv::ConstantOp>(
777       loc, integerType, IntegerAttr::get(integerType, offset));
778   for (auto index : llvm::enumerate(indices)) {
779     Value strideVal = builder.create<spirv::ConstantOp>(
780         loc, integerType,
781         IntegerAttr::get(integerType, strides[index.index()]));
782     Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
783     linearizedIndex =
784         builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
785   }
786   return linearizedIndex;
787 }
788 
789 spirv::AccessChainOp mlir::spirv::getElementPtr(
790     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
791     ValueRange indices, Location loc, OpBuilder &builder) {
792   // Get base and offset of the MemRefType and verify they are static.
793 
794   int64_t offset;
795   SmallVector<int64_t, 4> strides;
796   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
797       llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
798       offset == MemRefType::getDynamicStrideOrOffset()) {
799     return nullptr;
800   }
801 
802   auto indexType = typeConverter.getIndexType();
803 
804   SmallVector<Value, 2> linearizedIndices;
805   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
806 
807   // Add a '0' at the start to index into the struct.
808   linearizedIndices.push_back(zero);
809 
810   if (baseType.getRank() == 0) {
811     linearizedIndices.push_back(zero);
812   } else {
813     linearizedIndices.push_back(
814         linearizeIndex(indices, strides, offset, indexType, loc, builder));
815   }
816   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
817 }
818 
819 //===----------------------------------------------------------------------===//
820 // SPIR-V ConversionTarget
821 //===----------------------------------------------------------------------===//
822 
823 std::unique_ptr<SPIRVConversionTarget>
824 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
825   std::unique_ptr<SPIRVConversionTarget> target(
826       // std::make_unique does not work here because the constructor is private.
827       new SPIRVConversionTarget(targetAttr));
828   SPIRVConversionTarget *targetPtr = target.get();
829   target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
830       // We need to capture the raw pointer here because it is stable:
831       // target will be destroyed once this function is returned.
832       [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
833   return target;
834 }
835 
836 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
837     : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
838 
839 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
840   // Make sure this op is available at the given version. Ops not implementing
841   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
842   // SPIR-V versions.
843   if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
844     if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
845       LLVM_DEBUG(llvm::dbgs()
846                  << op->getName() << " illegal: requiring min version "
847                  << spirv::stringifyVersion(minVersion.getMinVersion())
848                  << "\n");
849       return false;
850     }
851   if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
852     if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
853       LLVM_DEBUG(llvm::dbgs()
854                  << op->getName() << " illegal: requiring max version "
855                  << spirv::stringifyVersion(maxVersion.getMaxVersion())
856                  << "\n");
857       return false;
858     }
859 
860   // Make sure this op's required extensions are allowed to use. Ops not
861   // implementing QueryExtensionInterface do not require extensions to be
862   // available.
863   if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
864     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
865                                           extensions.getExtensions())))
866       return false;
867 
868   // Make sure this op's required extensions are allowed to use. Ops not
869   // implementing QueryCapabilityInterface do not require capabilities to be
870   // available.
871   if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
872     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
873                                            capabilities.getCapabilities())))
874       return false;
875 
876   SmallVector<Type, 4> valueTypes;
877   valueTypes.append(op->operand_type_begin(), op->operand_type_end());
878   valueTypes.append(op->result_type_begin(), op->result_type_end());
879 
880   // Special treatment for global variables, whose type requirements are
881   // conveyed by type attributes.
882   if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
883     valueTypes.push_back(globalVar.type());
884 
885   // Make sure the op's operands/results use types that are allowed by the
886   // target environment.
887   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
888   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
889   for (Type valueType : valueTypes) {
890     typeExtensions.clear();
891     valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
892     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
893                                           typeExtensions)))
894       return false;
895 
896     typeCapabilities.clear();
897     valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
898     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
899                                            typeCapabilities)))
900       return false;
901   }
902 
903   return true;
904 }
905