1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 #include "llvm/ADT/Sequence.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Debug.h"
20 
21 #include <functional>
22 
23 #define DEBUG_TYPE "mlir-spirv-conversion"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Checks that `candidates` extension requirements are possible to be satisfied
32 /// with the given `targetEnv`.
33 ///
34 ///  `candidates` is a vector of vector for extension requirements following
35 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
36 /// convention.
37 template <typename LabelT>
38 static LogicalResult checkExtensionRequirements(
39     LabelT label, const spirv::TargetEnv &targetEnv,
40     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
41   for (const auto &ors : candidates) {
42     if (targetEnv.allows(ors))
43       continue;
44 
45     LLVM_DEBUG({
46       SmallVector<StringRef> extStrings;
47       for (spirv::Extension ext : ors)
48         extStrings.push_back(spirv::stringifyExtension(ext));
49 
50       llvm::dbgs() << label << " illegal: requires at least one extension in ["
51                    << llvm::join(extStrings, ", ")
52                    << "] but none allowed in target environment\n";
53     });
54     return failure();
55   }
56   return success();
57 }
58 
59 /// Checks that `candidates`capability requirements are possible to be satisfied
60 /// with the given `isAllowedFn`.
61 ///
62 ///  `candidates` is a vector of vector for capability requirements following
63 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
64 /// convention.
65 template <typename LabelT>
66 static LogicalResult checkCapabilityRequirements(
67     LabelT label, const spirv::TargetEnv &targetEnv,
68     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
69   for (const auto &ors : candidates) {
70     if (targetEnv.allows(ors))
71       continue;
72 
73     LLVM_DEBUG({
74       SmallVector<StringRef> capStrings;
75       for (spirv::Capability cap : ors)
76         capStrings.push_back(spirv::stringifyCapability(cap));
77 
78       llvm::dbgs() << label << " illegal: requires at least one capability in ["
79                    << llvm::join(capStrings, ", ")
80                    << "] but none allowed in target environment\n";
81     });
82     return failure();
83   }
84   return success();
85 }
86 
87 /// Returns true if the given `storageClass` needs explicit layout when used in
88 /// Shader environments.
89 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
90   switch (storageClass) {
91   case spirv::StorageClass::PhysicalStorageBuffer:
92   case spirv::StorageClass::PushConstant:
93   case spirv::StorageClass::StorageBuffer:
94   case spirv::StorageClass::Uniform:
95     return true;
96   default:
97     return false;
98   }
99 }
100 
101 /// Wraps the given `elementType` in a struct and gets the pointer to the
102 /// struct. This is used to satisfy Vulkan interface requirements.
103 static spirv::PointerType
104 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
105   auto structType = needsExplicitLayout(storageClass)
106                         ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
107                         : spirv::StructType::get(elementType);
108   return spirv::PointerType::get(structType, storageClass);
109 }
110 
111 //===----------------------------------------------------------------------===//
112 // Type Conversion
113 //===----------------------------------------------------------------------===//
114 
115 Type SPIRVTypeConverter::getIndexType() const {
116   return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
117 }
118 
/// Mapping between SPIR-V storage classes and memref memory spaces.
///
/// Each `MAP_FN(storageClass, memorySpace)` entry pairs one SPIR-V storage
/// class with the numeric memref memory space used to represent it. Users
/// expand the list with their own `MAP_FN` to generate code per entry.
///
/// Note: memref does not have defined semantics for each memory space; they
/// depend on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes a smaller number. The hope is to use a symbolic
/// memory space representation eventually, once memref supports it.
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
  MAP_FN(spirv::StorageClass::Image, 13)                                       \
  MAP_FN(spirv::StorageClass::CallableDataKHR, 14)                             \
  MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15)                     \
  MAP_FN(spirv::StorageClass::RayPayloadKHR, 16)                               \
  MAP_FN(spirv::StorageClass::HitAttributeKHR, 17)                             \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18)                       \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19)                       \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)                       \
  MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21)                            \
  MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22)                             \
  MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
152 
153 unsigned
154 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
155 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
156   case storage:                                                                \
157     return space;
158 
159   switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
160 #undef STORAGE_SPACE_MAP_FN
161   llvm_unreachable("unhandled storage class!");
162 }
163 
164 Optional<spirv::StorageClass>
165 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
166 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
167   case space:                                                                  \
168     return storage;
169 
170   switch (space) {
171     STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
172   default:
173     return llvm::None;
174   }
175 #undef STORAGE_SPACE_MAP_FN
176 }
177 
178 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
179   return options;
180 }
181 
182 MLIRContext *SPIRVTypeConverter::getContext() const {
183   return targetEnv.getAttr().getContext();
184 }
185 
// The storage class <-> memory space list is not expanded past this point in
// the file, so drop the macro here.
#undef STORAGE_SPACE_MAP_LIST
187 
188 // TODO: This is a utility function that should probably be exposed by the
189 // SPIR-V dialect. Keeping it local till the use case arises.
190 static Optional<int64_t>
191 getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
192   if (type.isa<spirv::ScalarType>()) {
193     auto bitWidth = type.getIntOrFloatBitWidth();
194     // According to the SPIR-V spec:
195     // "There is no physical size or bit pattern defined for values with boolean
196     // type. If they are stored (in conjunction with OpVariable), they can only
197     // be used with logical addressing operations, not physical, and only with
198     // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
199     // Private, Function, Input, and Output."
200     if (bitWidth == 1)
201       return llvm::None;
202     return bitWidth / 8;
203   }
204 
205   if (auto vecType = type.dyn_cast<VectorType>()) {
206     auto elementSize = getTypeNumBytes(options, vecType.getElementType());
207     if (!elementSize)
208       return llvm::None;
209     return vecType.getNumElements() * elementSize.getValue();
210   }
211 
212   if (auto memRefType = type.dyn_cast<MemRefType>()) {
213     // TODO: Layout should also be controlled by the ABI attributes. For now
214     // using the layout from MemRef.
215     int64_t offset;
216     SmallVector<int64_t, 4> strides;
217     if (!memRefType.hasStaticShape() ||
218         failed(getStridesAndOffset(memRefType, strides, offset)))
219       return llvm::None;
220 
221     // To get the size of the memref object in memory, the total size is the
222     // max(stride * dimension-size) computed for all dimensions times the size
223     // of the element.
224     auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
225     if (!elementSize)
226       return llvm::None;
227 
228     if (memRefType.getRank() == 0)
229       return elementSize;
230 
231     auto dims = memRefType.getShape();
232     if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
233         offset == MemRefType::getDynamicStrideOrOffset() ||
234         llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
235       return llvm::None;
236 
237     int64_t memrefSize = -1;
238     for (auto shape : enumerate(dims))
239       memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
240 
241     return (offset + memrefSize) * elementSize.getValue();
242   }
243 
244   if (auto tensorType = type.dyn_cast<TensorType>()) {
245     if (!tensorType.hasStaticShape())
246       return llvm::None;
247 
248     auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
249     if (!elementSize)
250       return llvm::None;
251 
252     int64_t size = elementSize.getValue();
253     for (auto shape : tensorType.getShape())
254       size *= shape;
255 
256     return size;
257   }
258 
259   // TODO: Add size computation for other types.
260   return llvm::None;
261 }
262 
263 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
264 static Type convertScalarType(const spirv::TargetEnv &targetEnv,
265                               const SPIRVTypeConverter::Options &options,
266                               spirv::ScalarType type,
267                               Optional<spirv::StorageClass> storageClass = {}) {
268   // Get extension and capability requirements for the given type.
269   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
270   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
271   type.getExtensions(extensions, storageClass);
272   type.getCapabilities(capabilities, storageClass);
273 
274   // If all requirements are met, then we can accept this type as-is.
275   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
276       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
277     return type;
278 
279   // Otherwise we need to adjust the type, which really means adjusting the
280   // bitwidth given this is a scalar type.
281 
282   if (!options.emulateNon32BitScalarTypes)
283     return nullptr;
284 
285   if (auto floatType = type.dyn_cast<FloatType>()) {
286     LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
287     return Builder(targetEnv.getContext()).getF32Type();
288   }
289 
290   auto intType = type.cast<IntegerType>();
291   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
292   return IntegerType::get(targetEnv.getContext(), /*width=*/32,
293                           intType.getSignedness());
294 }
295 
296 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
297 static Type convertVectorType(const spirv::TargetEnv &targetEnv,
298                               const SPIRVTypeConverter::Options &options,
299                               VectorType type,
300                               Optional<spirv::StorageClass> storageClass = {}) {
301   if (type.getRank() == 1 && type.getNumElements() == 1)
302     return type.getElementType();
303 
304   if (!spirv::CompositeType::isValid(type)) {
305     // TODO: Vector types with more than four elements can be translated into
306     // array types.
307     LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
308     return nullptr;
309   }
310 
311   // Get extension and capability requirements for the given type.
312   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
313   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
314   type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
315   type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
316 
317   // If all requirements are met, then we can accept this type as-is.
318   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
319       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
320     return type;
321 
322   auto elementType = convertScalarType(
323       targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
324       storageClass);
325   if (elementType)
326     return VectorType::get(type.getShape(), elementType);
327   return nullptr;
328 }
329 
330 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
331 ///
332 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
333 /// create composite constants with OpConstantComposite to embed relative large
334 /// constant values and use OpCompositeExtract and OpCompositeInsert to
335 /// manipulate, like what we do for vectors.
336 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
337                               const SPIRVTypeConverter::Options &options,
338                               TensorType type) {
339   // TODO: Handle dynamic shapes.
340   if (!type.hasStaticShape()) {
341     LLVM_DEBUG(llvm::dbgs()
342                << type << " illegal: dynamic shape unimplemented\n");
343     return nullptr;
344   }
345 
346   auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
347   if (!scalarType) {
348     LLVM_DEBUG(llvm::dbgs()
349                << type << " illegal: cannot convert non-scalar element type\n");
350     return nullptr;
351   }
352 
353   Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
354   Optional<int64_t> tensorSize = getTypeNumBytes(options, type);
355   if (!scalarSize || !tensorSize) {
356     LLVM_DEBUG(llvm::dbgs()
357                << type << " illegal: cannot deduce element count\n");
358     return nullptr;
359   }
360 
361   auto arrayElemCount = *tensorSize / *scalarSize;
362   auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
363   if (!arrayElemType)
364     return nullptr;
365   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
366   if (!arrayElemSize) {
367     LLVM_DEBUG(llvm::dbgs()
368                << type << " illegal: cannot deduce converted element size\n");
369     return nullptr;
370   }
371 
372   return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
373 }
374 
375 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
376                                   const SPIRVTypeConverter::Options &options,
377                                   MemRefType type) {
378   Optional<spirv::StorageClass> storageClass =
379       SPIRVTypeConverter::getStorageClassForMemorySpace(
380           type.getMemorySpaceAsInt());
381   if (!storageClass) {
382     LLVM_DEBUG(llvm::dbgs()
383                << type << " illegal: cannot convert memory space\n");
384     return nullptr;
385   }
386 
387   unsigned numBoolBits = options.boolNumBits;
388   if (numBoolBits != 8) {
389     LLVM_DEBUG(llvm::dbgs()
390                << "using non-8-bit storage for bool types unimplemented");
391     return nullptr;
392   }
393   auto elementType = IntegerType::get(type.getContext(), numBoolBits)
394                          .dyn_cast<spirv::ScalarType>();
395   if (!elementType)
396     return nullptr;
397   Type arrayElemType =
398       convertScalarType(targetEnv, options, elementType, storageClass);
399   if (!arrayElemType)
400     return nullptr;
401   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
402   if (!arrayElemSize) {
403     LLVM_DEBUG(llvm::dbgs()
404                << type << " illegal: cannot deduce converted element size\n");
405     return nullptr;
406   }
407 
408   if (!type.hasStaticShape()) {
409     auto arrayType =
410         spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
411     return wrapInStructAndGetPointer(arrayType, *storageClass);
412   }
413 
414   int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
415   auto arrayElemCount = (memrefSize + *arrayElemSize - 1) / *arrayElemSize;
416   auto arrayType =
417       spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
418 
419   return wrapInStructAndGetPointer(arrayType, *storageClass);
420 }
421 
422 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
423                               const SPIRVTypeConverter::Options &options,
424                               MemRefType type) {
425   if (type.getElementType().isa<IntegerType>() &&
426       type.getElementTypeBitWidth() == 1) {
427     return convertBoolMemrefType(targetEnv, options, type);
428   }
429 
430   Optional<spirv::StorageClass> storageClass =
431       SPIRVTypeConverter::getStorageClassForMemorySpace(
432           type.getMemorySpaceAsInt());
433   if (!storageClass) {
434     LLVM_DEBUG(llvm::dbgs()
435                << type << " illegal: cannot convert memory space\n");
436     return nullptr;
437   }
438 
439   Type arrayElemType;
440   Type elementType = type.getElementType();
441   if (auto vecType = elementType.dyn_cast<VectorType>()) {
442     arrayElemType =
443         convertVectorType(targetEnv, options, vecType, storageClass);
444   } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
445     arrayElemType =
446         convertScalarType(targetEnv, options, scalarType, storageClass);
447   } else {
448     LLVM_DEBUG(
449         llvm::dbgs()
450         << type
451         << " unhandled: can only convert scalar or vector element type\n");
452     return nullptr;
453   }
454   if (!arrayElemType)
455     return nullptr;
456 
457   Optional<int64_t> elementSize = getTypeNumBytes(options, elementType);
458   if (!elementSize) {
459     LLVM_DEBUG(llvm::dbgs()
460                << type << " illegal: cannot deduce element size\n");
461     return nullptr;
462   }
463 
464   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
465   if (!arrayElemSize) {
466     LLVM_DEBUG(llvm::dbgs()
467                << type << " illegal: cannot deduce converted element size\n");
468     return nullptr;
469   }
470 
471   if (!type.hasStaticShape()) {
472     auto arrayType =
473         spirv::RuntimeArrayType::get(arrayElemType, *arrayElemSize);
474     return wrapInStructAndGetPointer(arrayType, *storageClass);
475   }
476 
477   Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
478   if (!memrefSize) {
479     LLVM_DEBUG(llvm::dbgs()
480                << type << " illegal: cannot deduce element count\n");
481     return nullptr;
482   }
483 
484   auto arrayElemCount = *memrefSize / *elementSize;
485 
486 
487   auto arrayType =
488       spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
489 
490   return wrapInStructAndGetPointer(arrayType, *storageClass);
491 }
492 
493 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
494                                        Options options)
495     : targetEnv(targetAttr), options(options) {
496   // Add conversions. The order matters here: later ones will be tried earlier.
497 
498   // Allow all SPIR-V dialect specific types. This assumes all builtin types
499   // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
500   // were tried before.
501   //
502   // TODO: this assumes that the SPIR-V types are valid to use in
503   // the given target environment, which should be the case if the whole
504   // pipeline is driven by the same target environment. Still, we probably still
505   // want to validate and convert to be safe.
506   addConversion([](spirv::SPIRVType type) { return type; });
507 
508   addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
509 
510   addConversion([this](IntegerType intType) -> Optional<Type> {
511     if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
512       return convertScalarType(this->targetEnv, this->options, scalarType);
513     return Type();
514   });
515 
516   addConversion([this](FloatType floatType) -> Optional<Type> {
517     if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
518       return convertScalarType(this->targetEnv, this->options, scalarType);
519     return Type();
520   });
521 
522   addConversion([this](VectorType vectorType) {
523     return convertVectorType(this->targetEnv, this->options, vectorType);
524   });
525 
526   addConversion([this](TensorType tensorType) {
527     return convertTensorType(this->targetEnv, this->options, tensorType);
528   });
529 
530   addConversion([this](MemRefType memRefType) {
531     return convertMemrefType(this->targetEnv, this->options, memRefType);
532   });
533 }
534 
535 //===----------------------------------------------------------------------===//
536 // FuncOp Conversion Patterns
537 //===----------------------------------------------------------------------===//
538 
539 namespace {
540 /// A pattern for rewriting function signature to convert arguments of functions
541 /// to be of valid SPIR-V types.
542 class FuncOpConversion final : public OpConversionPattern<FuncOp> {
543 public:
544   using OpConversionPattern<FuncOp>::OpConversionPattern;
545 
546   LogicalResult
547   matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
548                   ConversionPatternRewriter &rewriter) const override;
549 };
550 } // namespace
551 
552 LogicalResult
553 FuncOpConversion::matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor,
554                                   ConversionPatternRewriter &rewriter) const {
555   auto fnType = funcOp.getType();
556   if (fnType.getNumResults() > 1)
557     return failure();
558 
559   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
560   for (auto argType : enumerate(fnType.getInputs())) {
561     auto convertedType = getTypeConverter()->convertType(argType.value());
562     if (!convertedType)
563       return failure();
564     signatureConverter.addInputs(argType.index(), convertedType);
565   }
566 
567   Type resultType;
568   if (fnType.getNumResults() == 1) {
569     resultType = getTypeConverter()->convertType(fnType.getResult(0));
570     if (!resultType)
571       return failure();
572   }
573 
574   // Create the converted spv.func op.
575   auto newFuncOp = rewriter.create<spirv::FuncOp>(
576       funcOp.getLoc(), funcOp.getName(),
577       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
578                                resultType ? TypeRange(resultType)
579                                           : TypeRange()));
580 
581   // Copy over all attributes other than the function name and type.
582   for (const auto &namedAttr : funcOp->getAttrs()) {
583     if (namedAttr.getName() != function_like_impl::getTypeAttrName() &&
584         namedAttr.getName() != SymbolTable::getSymbolAttrName())
585       newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
586   }
587 
588   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
589                               newFuncOp.end());
590   if (failed(rewriter.convertRegionTypes(
591           &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
592     return failure();
593   rewriter.eraseOp(funcOp);
594   return success();
595 }
596 
597 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
598                                               RewritePatternSet &patterns) {
599   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
600 }
601 
602 //===----------------------------------------------------------------------===//
603 // Builtin Variables
604 //===----------------------------------------------------------------------===//
605 
606 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
607                                                   spirv::BuiltIn builtin) {
608   // Look through all global variables in the given `body` block and check if
609   // there is a spv.GlobalVariable that has the same `builtin` attribute.
610   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
611     if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
612             spirv::SPIRVDialect::getAttributeName(
613                 spirv::Decoration::BuiltIn))) {
614       auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
615       if (varBuiltIn && varBuiltIn.getValue() == builtin) {
616         return varOp;
617       }
618     }
619   }
620   return nullptr;
621 }
622 
623 /// Gets name of global variable for a builtin.
624 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
625   return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
626 }
627 
628 /// Gets or inserts a global variable for a builtin within `body` block.
629 static spirv::GlobalVariableOp
630 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
631                            Type integerType, OpBuilder &builder) {
632   if (auto varOp = getBuiltinVariable(body, builtin))
633     return varOp;
634 
635   OpBuilder::InsertionGuard guard(builder);
636   builder.setInsertionPointToStart(&body);
637 
638   spirv::GlobalVariableOp newVarOp;
639   switch (builtin) {
640   case spirv::BuiltIn::NumWorkgroups:
641   case spirv::BuiltIn::WorkgroupSize:
642   case spirv::BuiltIn::WorkgroupId:
643   case spirv::BuiltIn::LocalInvocationId:
644   case spirv::BuiltIn::GlobalInvocationId: {
645     auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
646                                            spirv::StorageClass::Input);
647     std::string name = getBuiltinVarName(builtin);
648     newVarOp =
649         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
650     break;
651   }
652   case spirv::BuiltIn::SubgroupId:
653   case spirv::BuiltIn::NumSubgroups:
654   case spirv::BuiltIn::SubgroupSize: {
655     auto ptrType =
656         spirv::PointerType::get(integerType, spirv::StorageClass::Input);
657     std::string name = getBuiltinVarName(builtin);
658     newVarOp =
659         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
660     break;
661   }
662   default:
663     emitError(loc, "unimplemented builtin variable generation for ")
664         << stringifyBuiltIn(builtin);
665   }
666   return newVarOp;
667 }
668 
669 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
670                                            spirv::BuiltIn builtin,
671                                            Type integerType,
672                                            OpBuilder &builder) {
673   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
674   if (!parent) {
675     op->emitError("expected operation to be within a module-like op");
676     return nullptr;
677   }
678 
679   spirv::GlobalVariableOp varOp =
680       getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
681                                  builtin, integerType, builder);
682   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
683   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
684 }
685 
686 //===----------------------------------------------------------------------===//
687 // Push constant storage
688 //===----------------------------------------------------------------------===//
689 
690 /// Returns the pointer type for the push constant storage containing
691 /// `elementCount` 32-bit integer values.
692 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
693                                                      Builder &builder,
694                                                      Type indexType) {
695   auto arrayType = spirv::ArrayType::get(indexType, elementCount,
696                                          /*stride=*/4);
697   auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
698   return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
699 }
700 
701 /// Returns the push constant varible containing `elementCount` 32-bit integer
702 /// values in `body`. Returns null op if such an op does not exit.
703 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
704                                                        unsigned elementCount) {
705   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
706     auto ptrType = varOp.type().dyn_cast<spirv::PointerType>();
707     if (!ptrType)
708       continue;
709 
710     // Note that Vulkan requires "There must be no more than one push constant
711     // block statically used per shader entry point." So we should always reuse
712     // the existing one.
713     if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
714       auto numElements = ptrType.getPointeeType()
715                              .cast<spirv::StructType>()
716                              .getElementType(0)
717                              .cast<spirv::ArrayType>()
718                              .getNumElements();
719       if (numElements == elementCount)
720         return varOp;
721     }
722   }
723   return nullptr;
724 }
725 
726 /// Gets or inserts a global variable for push constant storage containing
727 /// `elementCount` 32-bit integer values in `block`.
728 static spirv::GlobalVariableOp
729 getOrInsertPushConstantVariable(Location loc, Block &block,
730                                 unsigned elementCount, OpBuilder &b,
731                                 Type indexType) {
732   if (auto varOp = getPushConstantVariable(block, elementCount))
733     return varOp;
734 
735   auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
736   auto type = getPushConstantStorageType(elementCount, builder, indexType);
737   const char *name = "__push_constant_var__";
738   return builder.create<spirv::GlobalVariableOp>(loc, type, name,
739                                                  /*initializer=*/nullptr);
740 }
741 
742 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
743                                   unsigned offset, Type integerType,
744                                   OpBuilder &builder) {
745   Location loc = op->getLoc();
746   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
747   if (!parent) {
748     op->emitError("expected operation to be within a module-like op");
749     return nullptr;
750   }
751 
752   spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
753       loc, parent->getRegion(0).front(), elementCount, builder, integerType);
754 
755   Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
756   Value offsetOp = builder.create<spirv::ConstantOp>(
757       loc, integerType, builder.getI32IntegerAttr(offset));
758   auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
759   auto acOp = builder.create<spirv::AccessChainOp>(
760       loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
761   return builder.create<spirv::LoadOp>(loc, acOp);
762 }
763 
764 //===----------------------------------------------------------------------===//
765 // Index calculation
766 //===----------------------------------------------------------------------===//
767 
768 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
769                                   int64_t offset, Type integerType,
770                                   Location loc, OpBuilder &builder) {
771   assert(indices.size() == strides.size() &&
772          "must provide indices for all dimensions");
773 
774   // TODO: Consider moving to use affine.apply and patterns converting
775   // affine.apply to standard ops. This needs converting to SPIR-V passes to be
776   // broken down into progressive small steps so we can have intermediate steps
777   // using other dialects. At the moment SPIR-V is the final sink.
778 
779   Value linearizedIndex = builder.create<spirv::ConstantOp>(
780       loc, integerType, IntegerAttr::get(integerType, offset));
781   for (auto index : llvm::enumerate(indices)) {
782     Value strideVal = builder.create<spirv::ConstantOp>(
783         loc, integerType,
784         IntegerAttr::get(integerType, strides[index.index()]));
785     Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
786     linearizedIndex =
787         builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
788   }
789   return linearizedIndex;
790 }
791 
792 spirv::AccessChainOp mlir::spirv::getElementPtr(
793     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
794     ValueRange indices, Location loc, OpBuilder &builder) {
795   // Get base and offset of the MemRefType and verify they are static.
796 
797   int64_t offset;
798   SmallVector<int64_t, 4> strides;
799   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
800       llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
801       offset == MemRefType::getDynamicStrideOrOffset()) {
802     return nullptr;
803   }
804 
805   auto indexType = typeConverter.getIndexType();
806 
807   SmallVector<Value, 2> linearizedIndices;
808   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
809 
810   // Add a '0' at the start to index into the struct.
811   linearizedIndices.push_back(zero);
812 
813   if (baseType.getRank() == 0) {
814     linearizedIndices.push_back(zero);
815   } else {
816     linearizedIndices.push_back(
817         linearizeIndex(indices, strides, offset, indexType, loc, builder));
818   }
819   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // SPIR-V ConversionTarget
824 //===----------------------------------------------------------------------===//
825 
826 std::unique_ptr<SPIRVConversionTarget>
827 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
828   std::unique_ptr<SPIRVConversionTarget> target(
829       // std::make_unique does not work here because the constructor is private.
830       new SPIRVConversionTarget(targetAttr));
831   SPIRVConversionTarget *targetPtr = target.get();
832   target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
833       // We need to capture the raw pointer here because it is stable:
834       // target will be destroyed once this function is returned.
835       [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
836   return target;
837 }
838 
/// Constructs a conversion target bound to the MLIR context of `targetAttr`
/// and stores the target environment built from `targetAttr` for later
/// legality queries.
SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
    : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
841 
842 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
843   // Make sure this op is available at the given version. Ops not implementing
844   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
845   // SPIR-V versions.
846   if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
847     Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
848     if (minVersion && *minVersion > this->targetEnv.getVersion()) {
849       LLVM_DEBUG(llvm::dbgs()
850                  << op->getName() << " illegal: requiring min version "
851                  << spirv::stringifyVersion(*minVersion) << "\n");
852       return false;
853     }
854   }
855   if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
856     Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
857     if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
858       LLVM_DEBUG(llvm::dbgs()
859                  << op->getName() << " illegal: requiring max version "
860                  << spirv::stringifyVersion(*maxVersion) << "\n");
861       return false;
862     }
863   }
864 
865   // Make sure this op's required extensions are allowed to use. Ops not
866   // implementing QueryExtensionInterface do not require extensions to be
867   // available.
868   if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
869     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
870                                           extensions.getExtensions())))
871       return false;
872 
873   // Make sure this op's required extensions are allowed to use. Ops not
874   // implementing QueryCapabilityInterface do not require capabilities to be
875   // available.
876   if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
877     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
878                                            capabilities.getCapabilities())))
879       return false;
880 
881   SmallVector<Type, 4> valueTypes;
882   valueTypes.append(op->operand_type_begin(), op->operand_type_end());
883   valueTypes.append(op->result_type_begin(), op->result_type_end());
884 
885   // Ensure that all types have been converted to SPIRV types.
886   if (llvm::any_of(valueTypes,
887                    [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
888     return false;
889 
890   // Special treatment for global variables, whose type requirements are
891   // conveyed by type attributes.
892   if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
893     valueTypes.push_back(globalVar.type());
894 
895   // Make sure the op's operands/results use types that are allowed by the
896   // target environment.
897   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
898   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
899   for (Type valueType : valueTypes) {
900     typeExtensions.clear();
901     valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
902     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
903                                           typeExtensions)))
904       return false;
905 
906     typeCapabilities.clear();
907     valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
908     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
909                                            typeCapabilities)))
910       return false;
911   }
912 
913   return true;
914 }
915