1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "llvm/ADT/Sequence.h"
17 #include "llvm/ADT/StringExtras.h"
18 #include "llvm/Support/Debug.h"
19 
20 #include <functional>
21 
22 #define DEBUG_TYPE "mlir-spirv-conversion"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Utility functions
28 //===----------------------------------------------------------------------===//
29 
30 /// Checks that `candidates` extension requirements are possible to be satisfied
31 /// with the given `targetEnv`.
32 ///
33 ///  `candidates` is a vector of vector for extension requirements following
34 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
35 /// convention.
36 template <typename LabelT>
37 static LogicalResult checkExtensionRequirements(
38     LabelT label, const spirv::TargetEnv &targetEnv,
39     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
40   for (const auto &ors : candidates) {
41     if (targetEnv.allows(ors))
42       continue;
43 
44     SmallVector<StringRef, 4> extStrings;
45     for (spirv::Extension ext : ors)
46       extStrings.push_back(spirv::stringifyExtension(ext));
47 
48     LLVM_DEBUG(llvm::dbgs()
49                << label << " illegal: requires at least one extension in ["
50                << llvm::join(extStrings, ", ")
51                << "] but none allowed in target environment\n");
52     return failure();
53   }
54   return success();
55 }
56 
57 /// Checks that `candidates`capability requirements are possible to be satisfied
58 /// with the given `isAllowedFn`.
59 ///
60 ///  `candidates` is a vector of vector for capability requirements following
61 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
62 /// convention.
63 template <typename LabelT>
64 static LogicalResult checkCapabilityRequirements(
65     LabelT label, const spirv::TargetEnv &targetEnv,
66     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
67   for (const auto &ors : candidates) {
68     if (targetEnv.allows(ors))
69       continue;
70 
71     SmallVector<StringRef, 4> capStrings;
72     for (spirv::Capability cap : ors)
73       capStrings.push_back(spirv::stringifyCapability(cap));
74 
75     LLVM_DEBUG(llvm::dbgs()
76                << label << " illegal: requires at least one capability in ["
77                << llvm::join(capStrings, ", ")
78                << "] but none allowed in target environment\n");
79     return failure();
80   }
81   return success();
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // Type Conversion
86 //===----------------------------------------------------------------------===//
87 
88 Type SPIRVTypeConverter::getIndexType(MLIRContext *context) {
89   // Convert to 32-bit integers for now. Might need a way to control this in
90   // future.
91   // TODO: It is probably better to make it 64-bit integers. To
92   // this some support is needed in SPIR-V dialect for Conversion
93   // instructions. The Vulkan spec requires the builtins like
94   // GlobalInvocationID, etc. to be 32-bit (unsigned) integers which should be
95   // SExtended to 64-bit for index computations.
96   return IntegerType::get(context, 32);
97 }
98 
/// Mapping from SPIR-V storage classes to memref memory spaces.
///
/// Note: memref does not have a defined semantics for each memory space; it
/// depends on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes a smaller number. The hope is to use symbolic
/// memory space representation eventually after memref supports it.
///
/// Each entry has the form `MAP_FN(storageClass, memorySpace)`. The list is
/// expanded into switch cases in both directions (storage class -> memory
/// space and memory space -> storage class), so every storage class and every
/// memory space number may appear at most once. Memory space 2 is currently
/// unassigned.
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
  MAP_FN(spirv::StorageClass::Image, 13)                                       \
  MAP_FN(spirv::StorageClass::CallableDataNV, 14)                              \
  MAP_FN(spirv::StorageClass::IncomingCallableDataNV, 15)                      \
  MAP_FN(spirv::StorageClass::RayPayloadNV, 16)                                \
  MAP_FN(spirv::StorageClass::HitAttributeNV, 17)                              \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadNV, 18)                        \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferNV, 19)                        \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)
129 
130 unsigned
131 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
132 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
133   case storage:                                                                \
134     return space;
135 
136   switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
137 #undef STORAGE_SPACE_MAP_FN
138   llvm_unreachable("unhandled storage class!");
139 }
140 
141 Optional<spirv::StorageClass>
142 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
143 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
144   case space:                                                                  \
145     return storage;
146 
147   switch (space) {
148     STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
149   default:
150     return llvm::None;
151   }
152 #undef STORAGE_SPACE_MAP_FN
153 }
154 
155 #undef STORAGE_SPACE_MAP_LIST
156 
157 // TODO: This is a utility function that should probably be
158 // exposed by the SPIR-V dialect. Keeping it local till the use case arises.
159 static Optional<int64_t> getTypeNumBytes(Type t) {
160   if (t.isa<spirv::ScalarType>()) {
161     auto bitWidth = t.getIntOrFloatBitWidth();
162     // According to the SPIR-V spec:
163     // "There is no physical size or bit pattern defined for values with boolean
164     // type. If they are stored (in conjunction with OpVariable), they can only
165     // be used with logical addressing operations, not physical, and only with
166     // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
167     // Private, Function, Input, and Output."
168     if (bitWidth == 1) {
169       return llvm::None;
170     }
171     return bitWidth / 8;
172   }
173   if (auto vecType = t.dyn_cast<VectorType>()) {
174     auto elementSize = getTypeNumBytes(vecType.getElementType());
175     if (!elementSize)
176       return llvm::None;
177     return vecType.getNumElements() * *elementSize;
178   }
179   if (auto memRefType = t.dyn_cast<MemRefType>()) {
180     // TODO: Layout should also be controlled by the ABI attributes. For now
181     // using the layout from MemRef.
182     int64_t offset;
183     SmallVector<int64_t, 4> strides;
184     if (!memRefType.hasStaticShape() ||
185         failed(getStridesAndOffset(memRefType, strides, offset))) {
186       return llvm::None;
187     }
188     // To get the size of the memref object in memory, the total size is the
189     // max(stride * dimension-size) computed for all dimensions times the size
190     // of the element.
191     auto elementSize = getTypeNumBytes(memRefType.getElementType());
192     if (!elementSize) {
193       return llvm::None;
194     }
195     if (memRefType.getRank() == 0) {
196       return elementSize;
197     }
198     auto dims = memRefType.getShape();
199     if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
200         offset == MemRefType::getDynamicStrideOrOffset() ||
201         llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
202       return llvm::None;
203     }
204     int64_t memrefSize = -1;
205     for (auto shape : enumerate(dims)) {
206       memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
207     }
208     return (offset + memrefSize) * elementSize.getValue();
209   } else if (auto tensorType = t.dyn_cast<TensorType>()) {
210     if (!tensorType.hasStaticShape()) {
211       return llvm::None;
212     }
213     auto elementSize = getTypeNumBytes(tensorType.getElementType());
214     if (!elementSize) {
215       return llvm::None;
216     }
217     int64_t size = elementSize.getValue();
218     for (auto shape : tensorType.getShape()) {
219       size *= shape;
220     }
221     return size;
222   }
223   // TODO: Add size computation for other types.
224   return llvm::None;
225 }
226 
227 Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
228   return getTypeNumBytes(t);
229 }
230 
231 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
232 static Optional<Type>
233 convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
234                   Optional<spirv::StorageClass> storageClass = {}) {
235   // Get extension and capability requirements for the given type.
236   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
237   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
238   type.getExtensions(extensions, storageClass);
239   type.getCapabilities(capabilities, storageClass);
240 
241   // If all requirements are met, then we can accept this type as-is.
242   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
243       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
244     return type;
245 
246   // Otherwise we need to adjust the type, which really means adjusting the
247   // bitwidth given this is a scalar type.
248   // TODO: We are unconditionally converting the bitwidth here,
249   // this might be okay for non-interface types (i.e., types used in
250   // Private/Function storage classes), but not for interface types (i.e.,
251   // types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
252   // This is because the later actually affects the ABI contract with the
253   // runtime. So we may want to expose a control on SPIRVTypeConverter to fail
254   // conversion if we cannot change there.
255 
256   if (auto floatType = type.dyn_cast<FloatType>()) {
257     LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
258     return Builder(targetEnv.getContext()).getF32Type();
259   }
260 
261   auto intType = type.cast<IntegerType>();
262   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
263   return IntegerType::get(targetEnv.getContext(), /*width=*/32,
264                           intType.getSignedness());
265 }
266 
267 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
268 static Optional<Type>
269 convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
270                   Optional<spirv::StorageClass> storageClass = {}) {
271   if (!spirv::CompositeType::isValid(type)) {
272     // TODO: One-element vector types can be translated into scalar
273     // types. Vector types with more than four elements can be translated into
274     // array types.
275     LLVM_DEBUG(llvm::dbgs()
276                << type << " illegal: 1- and > 4-element unimplemented\n");
277     return llvm::None;
278   }
279 
280   // Get extension and capability requirements for the given type.
281   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
282   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
283   type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
284   type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
285 
286   // If all requirements are met, then we can accept this type as-is.
287   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
288       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
289     return type;
290 
291   auto elementType = convertScalarType(
292       targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
293   if (elementType)
294     return VectorType::get(type.getShape(), *elementType);
295   return llvm::None;
296 }
297 
298 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
299 ///
300 /// Note that this is mainly for lowering constant tensors.In SPIR-V one can
301 /// create composite constants with OpConstantComposite to embed relative large
302 /// constant values and use OpCompositeExtract and OpCompositeInsert to
303 /// manipulate, like what we do for vectors.
304 static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
305                                         TensorType type) {
306   // TODO: Handle dynamic shapes.
307   if (!type.hasStaticShape()) {
308     LLVM_DEBUG(llvm::dbgs()
309                << type << " illegal: dynamic shape unimplemented\n");
310     return llvm::None;
311   }
312 
313   auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
314   if (!scalarType) {
315     LLVM_DEBUG(llvm::dbgs()
316                << type << " illegal: cannot convert non-scalar element type\n");
317     return llvm::None;
318   }
319 
320   Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
321   Optional<int64_t> tensorSize = getTypeNumBytes(type);
322   if (!scalarSize || !tensorSize) {
323     LLVM_DEBUG(llvm::dbgs()
324                << type << " illegal: cannot deduce element count\n");
325     return llvm::None;
326   }
327 
328   auto arrayElemCount = *tensorSize / *scalarSize;
329   auto arrayElemType = convertScalarType(targetEnv, scalarType);
330   if (!arrayElemType)
331     return llvm::None;
332   Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
333   if (!arrayElemSize) {
334     LLVM_DEBUG(llvm::dbgs()
335                << type << " illegal: cannot deduce converted element size\n");
336     return llvm::None;
337   }
338 
339   return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
340 }
341 
342 static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
343                                         MemRefType type) {
344   Optional<spirv::StorageClass> storageClass =
345       SPIRVTypeConverter::getStorageClassForMemorySpace(type.getMemorySpace());
346   if (!storageClass) {
347     LLVM_DEBUG(llvm::dbgs()
348                << type << " illegal: cannot convert memory space\n");
349     return llvm::None;
350   }
351 
352   Optional<Type> arrayElemType;
353   Type elementType = type.getElementType();
354   if (auto vecType = elementType.dyn_cast<VectorType>()) {
355     arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
356   } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
357     arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
358   } else {
359     LLVM_DEBUG(
360         llvm::dbgs()
361         << type
362         << " unhandled: can only convert scalar or vector element type\n");
363     return llvm::None;
364   }
365   if (!arrayElemType)
366     return llvm::None;
367 
368   Optional<int64_t> elementSize = getTypeNumBytes(elementType);
369   if (!elementSize) {
370     LLVM_DEBUG(llvm::dbgs()
371                << type << " illegal: cannot deduce element size\n");
372     return llvm::None;
373   }
374 
375   if (!type.hasStaticShape()) {
376     auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
377     // Wrap in a struct to satisfy Vulkan interface requirements.
378     auto structType = spirv::StructType::get(arrayType, 0);
379     return spirv::PointerType::get(structType, *storageClass);
380   }
381 
382   Optional<int64_t> memrefSize = getTypeNumBytes(type);
383   if (!memrefSize) {
384     LLVM_DEBUG(llvm::dbgs()
385                << type << " illegal: cannot deduce element count\n");
386     return llvm::None;
387   }
388 
389   auto arrayElemCount = *memrefSize / *elementSize;
390 
391   Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
392   if (!arrayElemSize) {
393     LLVM_DEBUG(llvm::dbgs()
394                << type << " illegal: cannot deduce converted element size\n");
395     return llvm::None;
396   }
397 
398   auto arrayType =
399       spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
400 
401   // Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
402   // workgroup storage class do not need the struct to be laid out explicitly.
403   auto structType = *storageClass == spirv::StorageClass::Workgroup
404                         ? spirv::StructType::get(arrayType)
405                         : spirv::StructType::get(arrayType, 0);
406   return spirv::PointerType::get(structType, *storageClass);
407 }
408 
/// Builds a type converter whose conversions target SPIR-V types usable under
/// the target environment described by `targetAttr`.
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
    : targetEnv(targetAttr) {
  // Add conversions. The order matters here: later ones will be tried earlier.
  // NOTE(review): several lambdas below capture `this` to reach `targetEnv`;
  // presumably the converter must not be copied or moved while these
  // callbacks may still run — confirm.

  // All other cases failed. Then we cannot convert this type.
  addConversion([](Type type) { return llvm::None; });

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: this assumes that the SPIR-V types are valid to use in
  // the given target environment, which should be the case if the whole
  // pipeline is driven by the same target environment. Still, we probably still
  // want to validate and convert to be safe.
  addConversion([](spirv::SPIRVType type) { return type; });

  // `index` is lowered to the fixed integer type from getIndexType().
  addConversion([](IndexType indexType) {
    return SPIRVTypeConverter::getIndexType(indexType.getContext());
  });

  // Integer types that are SPIR-V scalar types are adjusted to the target
  // environment; any other integer type is rejected.
  addConversion([this](IntegerType intType) -> Optional<Type> {
    if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  // Float types that are SPIR-V scalar types are adjusted to the target
  // environment; any other float type is rejected.
  addConversion([this](FloatType floatType) -> Optional<Type> {
    if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(targetEnv, scalarType);
    return llvm::None;
  });

  // Vectors, tensors and memrefs are delegated to their dedicated helpers,
  // which do not pass a storage class for vectors/tensors here.
  addConversion([this](VectorType vectorType) {
    return convertVectorType(targetEnv, vectorType);
  });

  addConversion([this](TensorType tensorType) {
    return convertTensorType(targetEnv, tensorType);
  });

  addConversion([this](MemRefType memRefType) {
    return convertMemrefType(targetEnv, memRefType);
  });
}
454 
455 //===----------------------------------------------------------------------===//
456 // FuncOp Conversion Patterns
457 //===----------------------------------------------------------------------===//
458 
namespace {
/// A pattern that rewrites a builtin `func` op into an spv.func op, converting
/// its signature so that argument (and result) types are valid SPIR-V types.
class FuncOpConversion final : public SPIRVOpLowering<FuncOp> {
public:
  // Reuse the base class constructors.
  using SPIRVOpLowering<FuncOp>::SPIRVOpLowering;

  /// Replaces `funcOp` with an spv.func carrying the converted signature, the
  /// original attributes (except name/type) and the original body. Fails if
  /// the function has more than one result or an argument type cannot be
  /// converted.
  LogicalResult
  matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace
471 
472 LogicalResult
473 FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
474                                   ConversionPatternRewriter &rewriter) const {
475   auto fnType = funcOp.getType();
476   if (fnType.getNumResults() > 1)
477     return failure();
478 
479   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
480   for (auto argType : enumerate(fnType.getInputs())) {
481     auto convertedType = typeConverter.convertType(argType.value());
482     if (!convertedType)
483       return failure();
484     signatureConverter.addInputs(argType.index(), convertedType);
485   }
486 
487   Type resultType;
488   if (fnType.getNumResults() == 1)
489     resultType = typeConverter.convertType(fnType.getResult(0));
490 
491   // Create the converted spv.func op.
492   auto newFuncOp = rewriter.create<spirv::FuncOp>(
493       funcOp.getLoc(), funcOp.getName(),
494       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
495                                resultType ? TypeRange(resultType)
496                                           : TypeRange()));
497 
498   // Copy over all attributes other than the function name and type.
499   for (const auto &namedAttr : funcOp.getAttrs()) {
500     if (namedAttr.first != impl::getTypeAttrName() &&
501         namedAttr.first != SymbolTable::getSymbolAttrName())
502       newFuncOp->setAttr(namedAttr.first, namedAttr.second);
503   }
504 
505   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
506                               newFuncOp.end());
507   if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), typeConverter,
508                                          &signatureConverter)))
509     return failure();
510   rewriter.eraseOp(funcOp);
511   return success();
512 }
513 
/// Appends to `patterns` the FuncOpConversion pattern, which rewrites builtin
/// functions into spv.func ops using `typeConverter`.
void mlir::populateBuiltinFuncToSPIRVPatterns(
    MLIRContext *context, SPIRVTypeConverter &typeConverter,
    OwningRewritePatternList &patterns) {
  patterns.insert<FuncOpConversion>(context, typeConverter);
}
519 
520 //===----------------------------------------------------------------------===//
521 // Builtin Variables
522 //===----------------------------------------------------------------------===//
523 
524 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
525                                                   spirv::BuiltIn builtin) {
526   // Look through all global variables in the given `body` block and check if
527   // there is a spv.globalVariable that has the same `builtin` attribute.
528   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
529     if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
530             spirv::SPIRVDialect::getAttributeName(
531                 spirv::Decoration::BuiltIn))) {
532       auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
533       if (varBuiltIn && varBuiltIn.getValue() == builtin) {
534         return varOp;
535       }
536     }
537   }
538   return nullptr;
539 }
540 
541 /// Gets name of global variable for a builtin.
542 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
543   return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
544 }
545 
546 /// Gets or inserts a global variable for a builtin within `body` block.
547 static spirv::GlobalVariableOp
548 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
549                            OpBuilder &builder) {
550   if (auto varOp = getBuiltinVariable(body, builtin))
551     return varOp;
552 
553   OpBuilder::InsertionGuard guard(builder);
554   builder.setInsertionPointToStart(&body);
555 
556   spirv::GlobalVariableOp newVarOp;
557   switch (builtin) {
558   case spirv::BuiltIn::NumWorkgroups:
559   case spirv::BuiltIn::WorkgroupSize:
560   case spirv::BuiltIn::WorkgroupId:
561   case spirv::BuiltIn::LocalInvocationId:
562   case spirv::BuiltIn::GlobalInvocationId: {
563     auto ptrType = spirv::PointerType::get(
564         VectorType::get({3}, builder.getIntegerType(32)),
565         spirv::StorageClass::Input);
566     std::string name = getBuiltinVarName(builtin);
567     newVarOp =
568         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
569     break;
570   }
571   case spirv::BuiltIn::SubgroupId:
572   case spirv::BuiltIn::NumSubgroups:
573   case spirv::BuiltIn::SubgroupSize: {
574     auto ptrType = spirv::PointerType::get(builder.getIntegerType(32),
575                                            spirv::StorageClass::Input);
576     std::string name = getBuiltinVarName(builtin);
577     newVarOp =
578         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
579     break;
580   }
581   default:
582     emitError(loc, "unimplemented builtin variable generation for ")
583         << stringifyBuiltIn(builtin);
584   }
585   return newVarOp;
586 }
587 
588 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
589                                            spirv::BuiltIn builtin,
590                                            OpBuilder &builder) {
591   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
592   if (!parent) {
593     op->emitError("expected operation to be within a module-like op");
594     return nullptr;
595   }
596 
597   spirv::GlobalVariableOp varOp = getOrInsertBuiltinVariable(
598       *parent->getRegion(0).begin(), op->getLoc(), builtin, builder);
599   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
600   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
601 }
602 
603 //===----------------------------------------------------------------------===//
604 // Index calculation
605 //===----------------------------------------------------------------------===//
606 
607 spirv::AccessChainOp mlir::spirv::getElementPtr(
608     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
609     ValueRange indices, Location loc, OpBuilder &builder) {
610   // Get base and offset of the MemRefType and verify they are static.
611 
612   int64_t offset;
613   SmallVector<int64_t, 4> strides;
614   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
615       llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
616       offset == MemRefType::getDynamicStrideOrOffset()) {
617     return nullptr;
618   }
619 
620   auto indexType = typeConverter.getIndexType(builder.getContext());
621 
622   SmallVector<Value, 2> linearizedIndices;
623   // Add a '0' at the start to index into the struct.
624   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
625   linearizedIndices.push_back(zero);
626 
627   if (baseType.getRank() == 0) {
628     linearizedIndices.push_back(zero);
629   } else {
630     // TODO: Instead of this logic, use affine.apply and add patterns for
631     // lowering affine.apply to standard ops. These will get lowered to SPIR-V
632     // ops by the DialectConversion framework.
633     Value ptrLoc = builder.create<spirv::ConstantOp>(
634         loc, indexType, IntegerAttr::get(indexType, offset));
635     assert(indices.size() == strides.size() &&
636            "must provide indices for all dimensions");
637     for (auto index : llvm::enumerate(indices)) {
638       Value strideVal = builder.create<spirv::ConstantOp>(
639           loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
640       Value update =
641           builder.create<spirv::IMulOp>(loc, strideVal, index.value());
642       ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
643     }
644     linearizedIndices.push_back(ptrLoc);
645   }
646   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
647 }
648 
649 //===----------------------------------------------------------------------===//
650 // Set ABI attributes for lowering entry functions.
651 //===----------------------------------------------------------------------===//
652 
653 LogicalResult
654 mlir::spirv::setABIAttrs(spirv::FuncOp funcOp,
655                          spirv::EntryPointABIAttr entryPointInfo,
656                          ArrayRef<spirv::InterfaceVarABIAttr> argABIInfo) {
657   // Set the attributes for argument and the function.
658   StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
659   for (auto argIndex : llvm::seq<unsigned>(0, argABIInfo.size())) {
660     funcOp.setArgAttr(argIndex, argABIAttrName, argABIInfo[argIndex]);
661   }
662   funcOp->setAttr(spirv::getEntryPointABIAttrName(), entryPointInfo);
663   return success();
664 }
665 
666 //===----------------------------------------------------------------------===//
667 // SPIR-V ConversionTarget
668 //===----------------------------------------------------------------------===//
669 
670 std::unique_ptr<spirv::SPIRVConversionTarget>
671 spirv::SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
672   std::unique_ptr<SPIRVConversionTarget> target(
673       // std::make_unique does not work here because the constructor is private.
674       new SPIRVConversionTarget(targetAttr));
675   SPIRVConversionTarget *targetPtr = target.get();
676   target->addDynamicallyLegalDialect<SPIRVDialect>(
677       // We need to capture the raw pointer here because it is stable:
678       // target will be destroyed once this function is returned.
679       [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
680   return target;
681 }
682 
/// Constructs the target in `targetAttr`'s context and records the target
/// environment that isLegalOp() checks against.
spirv::SPIRVConversionTarget::SPIRVConversionTarget(
    spirv::TargetEnvAttr targetAttr)
    : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
686 
687 bool spirv::SPIRVConversionTarget::isLegalOp(Operation *op) {
688   // Make sure this op is available at the given version. Ops not implementing
689   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
690   // SPIR-V versions.
691   if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
692     if (minVersion.getMinVersion() > this->targetEnv.getVersion()) {
693       LLVM_DEBUG(llvm::dbgs()
694                  << op->getName() << " illegal: requiring min version "
695                  << spirv::stringifyVersion(minVersion.getMinVersion())
696                  << "\n");
697       return false;
698     }
699   if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
700     if (maxVersion.getMaxVersion() < this->targetEnv.getVersion()) {
701       LLVM_DEBUG(llvm::dbgs()
702                  << op->getName() << " illegal: requiring max version "
703                  << spirv::stringifyVersion(maxVersion.getMaxVersion())
704                  << "\n");
705       return false;
706     }
707 
708   // Make sure this op's required extensions are allowed to use. Ops not
709   // implementing QueryExtensionInterface do not require extensions to be
710   // available.
711   if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
712     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
713                                           extensions.getExtensions())))
714       return false;
715 
716   // Make sure this op's required extensions are allowed to use. Ops not
717   // implementing QueryCapabilityInterface do not require capabilities to be
718   // available.
719   if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
720     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
721                                            capabilities.getCapabilities())))
722       return false;
723 
724   SmallVector<Type, 4> valueTypes;
725   valueTypes.append(op->operand_type_begin(), op->operand_type_end());
726   valueTypes.append(op->result_type_begin(), op->result_type_end());
727 
728   // Special treatment for global variables, whose type requirements are
729   // conveyed by type attributes.
730   if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
731     valueTypes.push_back(globalVar.type());
732 
733   // Make sure the op's operands/results use types that are allowed by the
734   // target environment.
735   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
736   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
737   for (Type valueType : valueTypes) {
738     typeExtensions.clear();
739     valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
740     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
741                                           typeExtensions)))
742       return false;
743 
744     typeCapabilities.clear();
745     valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
746     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
747                                            typeCapabilities)))
748       return false;
749   }
750 
751   return true;
752 }
753