1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Support/Debug.h"
21 
22 #include <functional>
23 
24 #define DEBUG_TYPE "mlir-spirv-conversion"
25 
26 using namespace mlir;
27 
28 //===----------------------------------------------------------------------===//
29 // Utility functions
30 //===----------------------------------------------------------------------===//
31 
32 /// Checks that `candidates` extension requirements are possible to be satisfied
33 /// with the given `targetEnv`.
34 ///
35 ///  `candidates` is a vector of vector for extension requirements following
36 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
37 /// convention.
38 template <typename LabelT>
checkExtensionRequirements(LabelT label,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::ExtensionArrayRefVector & candidates)39 static LogicalResult checkExtensionRequirements(
40     LabelT label, const spirv::TargetEnv &targetEnv,
41     const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
42   for (const auto &ors : candidates) {
43     if (targetEnv.allows(ors))
44       continue;
45 
46     LLVM_DEBUG({
47       SmallVector<StringRef> extStrings;
48       for (spirv::Extension ext : ors)
49         extStrings.push_back(spirv::stringifyExtension(ext));
50 
51       llvm::dbgs() << label << " illegal: requires at least one extension in ["
52                    << llvm::join(extStrings, ", ")
53                    << "] but none allowed in target environment\n";
54     });
55     return failure();
56   }
57   return success();
58 }
59 
60 /// Checks that `candidates`capability requirements are possible to be satisfied
61 /// with the given `isAllowedFn`.
62 ///
63 ///  `candidates` is a vector of vector for capability requirements following
64 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
65 /// convention.
66 template <typename LabelT>
checkCapabilityRequirements(LabelT label,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::CapabilityArrayRefVector & candidates)67 static LogicalResult checkCapabilityRequirements(
68     LabelT label, const spirv::TargetEnv &targetEnv,
69     const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
70   for (const auto &ors : candidates) {
71     if (targetEnv.allows(ors))
72       continue;
73 
74     LLVM_DEBUG({
75       SmallVector<StringRef> capStrings;
76       for (spirv::Capability cap : ors)
77         capStrings.push_back(spirv::stringifyCapability(cap));
78 
79       llvm::dbgs() << label << " illegal: requires at least one capability in ["
80                    << llvm::join(capStrings, ", ")
81                    << "] but none allowed in target environment\n";
82     });
83     return failure();
84   }
85   return success();
86 }
87 
88 /// Returns true if the given `storageClass` needs explicit layout when used in
89 /// Shader environments.
needsExplicitLayout(spirv::StorageClass storageClass)90 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
91   switch (storageClass) {
92   case spirv::StorageClass::PhysicalStorageBuffer:
93   case spirv::StorageClass::PushConstant:
94   case spirv::StorageClass::StorageBuffer:
95   case spirv::StorageClass::Uniform:
96     return true;
97   default:
98     return false;
99   }
100 }
101 
102 /// Wraps the given `elementType` in a struct and gets the pointer to the
103 /// struct. This is used to satisfy Vulkan interface requirements.
104 static spirv::PointerType
wrapInStructAndGetPointer(Type elementType,spirv::StorageClass storageClass)105 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
106   auto structType = needsExplicitLayout(storageClass)
107                         ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
108                         : spirv::StructType::get(elementType);
109   return spirv::PointerType::get(structType, storageClass);
110 }
111 
112 //===----------------------------------------------------------------------===//
113 // Type Conversion
114 //===----------------------------------------------------------------------===//
115 
getIndexType() const116 Type SPIRVTypeConverter::getIndexType() const {
117   return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
118 }
119 
/// Mapping between SPIR-V storage classes and memref memory spaces.
///
/// Each entry has the form `MAP_FN(<storage class>, <memory space number>)`.
/// Users pass a macro as `MAP_FN` to expand the whole list, e.g. into the
/// `case` labels of a `switch`.
///
/// Note: memref does not have defined semantics for each memory space; it
/// depends on the context where it is used. There are no particular reasons
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes a smaller number. The hope is to use a symbolic
/// memory space representation eventually, once memref supports it.
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
  MAP_FN(spirv::StorageClass::Image, 13)                                       \
  MAP_FN(spirv::StorageClass::CallableDataKHR, 14)                             \
  MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15)                     \
  MAP_FN(spirv::StorageClass::RayPayloadKHR, 16)                               \
  MAP_FN(spirv::StorageClass::HitAttributeKHR, 17)                             \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18)                       \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19)                       \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)                       \
  MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21)                            \
  MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22)                             \
  MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
153 
154 unsigned
getMemorySpaceForStorageClass(spirv::StorageClass storage)155 SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
156 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
157   case storage:                                                                \
158     return space;
159 
160   switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
161 #undef STORAGE_SPACE_MAP_FN
162   llvm_unreachable("unhandled storage class!");
163 }
164 
165 Optional<spirv::StorageClass>
getStorageClassForMemorySpace(unsigned space)166 SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
167 #define STORAGE_SPACE_MAP_FN(storage, space)                                   \
168   case space:                                                                  \
169     return storage;
170 
171   switch (space) {
172     STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
173   default:
174     return llvm::None;
175   }
176 #undef STORAGE_SPACE_MAP_FN
177 }
178 
getOptions() const179 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
180   return options;
181 }
182 
getContext() const183 MLIRContext *SPIRVTypeConverter::getContext() const {
184   return targetEnv.getAttr().getContext();
185 }
186 
187 #undef STORAGE_SPACE_MAP_LIST
188 
189 // TODO: This is a utility function that should probably be exposed by the
190 // SPIR-V dialect. Keeping it local till the use case arises.
191 static Optional<int64_t>
getTypeNumBytes(const SPIRVTypeConverter::Options & options,Type type)192 getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
193   if (type.isa<spirv::ScalarType>()) {
194     auto bitWidth = type.getIntOrFloatBitWidth();
195     // According to the SPIR-V spec:
196     // "There is no physical size or bit pattern defined for values with boolean
197     // type. If they are stored (in conjunction with OpVariable), they can only
198     // be used with logical addressing operations, not physical, and only with
199     // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
200     // Private, Function, Input, and Output."
201     if (bitWidth == 1)
202       return llvm::None;
203     return bitWidth / 8;
204   }
205 
206   if (auto vecType = type.dyn_cast<VectorType>()) {
207     auto elementSize = getTypeNumBytes(options, vecType.getElementType());
208     if (!elementSize)
209       return llvm::None;
210     return vecType.getNumElements() * *elementSize;
211   }
212 
213   if (auto memRefType = type.dyn_cast<MemRefType>()) {
214     // TODO: Layout should also be controlled by the ABI attributes. For now
215     // using the layout from MemRef.
216     int64_t offset;
217     SmallVector<int64_t, 4> strides;
218     if (!memRefType.hasStaticShape() ||
219         failed(getStridesAndOffset(memRefType, strides, offset)))
220       return llvm::None;
221 
222     // To get the size of the memref object in memory, the total size is the
223     // max(stride * dimension-size) computed for all dimensions times the size
224     // of the element.
225     auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
226     if (!elementSize)
227       return llvm::None;
228 
229     if (memRefType.getRank() == 0)
230       return elementSize;
231 
232     auto dims = memRefType.getShape();
233     if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
234         offset == MemRefType::getDynamicStrideOrOffset() ||
235         llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
236       return llvm::None;
237 
238     int64_t memrefSize = -1;
239     for (const auto &shape : enumerate(dims))
240       memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
241 
242     return (offset + memrefSize) * *elementSize;
243   }
244 
245   if (auto tensorType = type.dyn_cast<TensorType>()) {
246     if (!tensorType.hasStaticShape())
247       return llvm::None;
248 
249     auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
250     if (!elementSize)
251       return llvm::None;
252 
253     int64_t size = *elementSize;
254     for (auto shape : tensorType.getShape())
255       size *= shape;
256 
257     return size;
258   }
259 
260   // TODO: Add size computation for other types.
261   return llvm::None;
262 }
263 
264 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
convertScalarType(const spirv::TargetEnv & targetEnv,const SPIRVTypeConverter::Options & options,spirv::ScalarType type,Optional<spirv::StorageClass> storageClass={})265 static Type convertScalarType(const spirv::TargetEnv &targetEnv,
266                               const SPIRVTypeConverter::Options &options,
267                               spirv::ScalarType type,
268                               Optional<spirv::StorageClass> storageClass = {}) {
269   // Get extension and capability requirements for the given type.
270   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
271   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
272   type.getExtensions(extensions, storageClass);
273   type.getCapabilities(capabilities, storageClass);
274 
275   // If all requirements are met, then we can accept this type as-is.
276   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
277       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
278     return type;
279 
280   // Otherwise we need to adjust the type, which really means adjusting the
281   // bitwidth given this is a scalar type.
282 
283   if (!options.emulateNon32BitScalarTypes)
284     return nullptr;
285 
286   if (auto floatType = type.dyn_cast<FloatType>()) {
287     LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
288     return Builder(targetEnv.getContext()).getF32Type();
289   }
290 
291   auto intType = type.cast<IntegerType>();
292   LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
293   return IntegerType::get(targetEnv.getContext(), /*width=*/32,
294                           intType.getSignedness());
295 }
296 
297 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
convertVectorType(const spirv::TargetEnv & targetEnv,const SPIRVTypeConverter::Options & options,VectorType type,Optional<spirv::StorageClass> storageClass={})298 static Type convertVectorType(const spirv::TargetEnv &targetEnv,
299                               const SPIRVTypeConverter::Options &options,
300                               VectorType type,
301                               Optional<spirv::StorageClass> storageClass = {}) {
302   if (type.getRank() == 1 && type.getNumElements() == 1)
303     return type.getElementType();
304 
305   if (!spirv::CompositeType::isValid(type)) {
306     // TODO: Vector types with more than four elements can be translated into
307     // array types.
308     LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
309     return nullptr;
310   }
311 
312   // Get extension and capability requirements for the given type.
313   SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
314   SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
315   type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
316   type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
317 
318   // If all requirements are met, then we can accept this type as-is.
319   if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
320       succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
321     return type;
322 
323   auto elementType = convertScalarType(
324       targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
325       storageClass);
326   if (elementType)
327     return VectorType::get(type.getShape(), elementType);
328   return nullptr;
329 }
330 
331 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
332 ///
333 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
334 /// create composite constants with OpConstantComposite to embed relative large
335 /// constant values and use OpCompositeExtract and OpCompositeInsert to
336 /// manipulate, like what we do for vectors.
convertTensorType(const spirv::TargetEnv & targetEnv,const SPIRVTypeConverter::Options & options,TensorType type)337 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
338                               const SPIRVTypeConverter::Options &options,
339                               TensorType type) {
340   // TODO: Handle dynamic shapes.
341   if (!type.hasStaticShape()) {
342     LLVM_DEBUG(llvm::dbgs()
343                << type << " illegal: dynamic shape unimplemented\n");
344     return nullptr;
345   }
346 
347   auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
348   if (!scalarType) {
349     LLVM_DEBUG(llvm::dbgs()
350                << type << " illegal: cannot convert non-scalar element type\n");
351     return nullptr;
352   }
353 
354   Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
355   Optional<int64_t> tensorSize = getTypeNumBytes(options, type);
356   if (!scalarSize || !tensorSize) {
357     LLVM_DEBUG(llvm::dbgs()
358                << type << " illegal: cannot deduce element count\n");
359     return nullptr;
360   }
361 
362   auto arrayElemCount = *tensorSize / *scalarSize;
363   auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
364   if (!arrayElemType)
365     return nullptr;
366   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
367   if (!arrayElemSize) {
368     LLVM_DEBUG(llvm::dbgs()
369                << type << " illegal: cannot deduce converted element size\n");
370     return nullptr;
371   }
372 
373   return spirv::ArrayType::get(arrayElemType, arrayElemCount);
374 }
375 
convertBoolMemrefType(const spirv::TargetEnv & targetEnv,const SPIRVTypeConverter::Options & options,MemRefType type)376 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
377                                   const SPIRVTypeConverter::Options &options,
378                                   MemRefType type) {
379   Optional<spirv::StorageClass> storageClass =
380       SPIRVTypeConverter::getStorageClassForMemorySpace(
381           type.getMemorySpaceAsInt());
382   if (!storageClass) {
383     LLVM_DEBUG(llvm::dbgs()
384                << type << " illegal: cannot convert memory space\n");
385     return nullptr;
386   }
387 
388   unsigned numBoolBits = options.boolNumBits;
389   if (numBoolBits != 8) {
390     LLVM_DEBUG(llvm::dbgs()
391                << "using non-8-bit storage for bool types unimplemented");
392     return nullptr;
393   }
394   auto elementType = IntegerType::get(type.getContext(), numBoolBits)
395                          .dyn_cast<spirv::ScalarType>();
396   if (!elementType)
397     return nullptr;
398   Type arrayElemType =
399       convertScalarType(targetEnv, options, elementType, storageClass);
400   if (!arrayElemType)
401     return nullptr;
402   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
403   if (!arrayElemSize) {
404     LLVM_DEBUG(llvm::dbgs()
405                << type << " illegal: cannot deduce converted element size\n");
406     return nullptr;
407   }
408 
409   if (!type.hasStaticShape()) {
410     int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
411     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
412     return wrapInStructAndGetPointer(arrayType, *storageClass);
413   }
414 
415   int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
416   auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
417   int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
418   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
419 
420   return wrapInStructAndGetPointer(arrayType, *storageClass);
421 }
422 
convertMemrefType(const spirv::TargetEnv & targetEnv,const SPIRVTypeConverter::Options & options,MemRefType type)423 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
424                               const SPIRVTypeConverter::Options &options,
425                               MemRefType type) {
426   if (type.getElementType().isa<IntegerType>() &&
427       type.getElementTypeBitWidth() == 1) {
428     return convertBoolMemrefType(targetEnv, options, type);
429   }
430 
431   Optional<spirv::StorageClass> storageClass =
432       SPIRVTypeConverter::getStorageClassForMemorySpace(
433           type.getMemorySpaceAsInt());
434   if (!storageClass) {
435     LLVM_DEBUG(llvm::dbgs()
436                << type << " illegal: cannot convert memory space\n");
437     return nullptr;
438   }
439 
440   Type arrayElemType;
441   Type elementType = type.getElementType();
442   if (auto vecType = elementType.dyn_cast<VectorType>()) {
443     arrayElemType =
444         convertVectorType(targetEnv, options, vecType, storageClass);
445   } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
446     arrayElemType =
447         convertScalarType(targetEnv, options, scalarType, storageClass);
448   } else {
449     LLVM_DEBUG(
450         llvm::dbgs()
451         << type
452         << " unhandled: can only convert scalar or vector element type\n");
453     return nullptr;
454   }
455   if (!arrayElemType)
456     return nullptr;
457 
458   Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
459   if (!arrayElemSize) {
460     LLVM_DEBUG(llvm::dbgs()
461                << type << " illegal: cannot deduce converted element size\n");
462     return nullptr;
463   }
464 
465   if (!type.hasStaticShape()) {
466     int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
467     auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
468     return wrapInStructAndGetPointer(arrayType, *storageClass);
469   }
470 
471   Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
472   if (!memrefSize) {
473     LLVM_DEBUG(llvm::dbgs()
474                << type << " illegal: cannot deduce element count\n");
475     return nullptr;
476   }
477 
478   auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
479   int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
480   auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
481 
482   return wrapInStructAndGetPointer(arrayType, *storageClass);
483 }
484 
SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,Options options)485 SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
486                                        Options options)
487     : targetEnv(targetAttr), options(options) {
488   // Add conversions. The order matters here: later ones will be tried earlier.
489 
490   // Allow all SPIR-V dialect specific types. This assumes all builtin types
491   // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
492   // were tried before.
493   //
494   // TODO: this assumes that the SPIR-V types are valid to use in
495   // the given target environment, which should be the case if the whole
496   // pipeline is driven by the same target environment. Still, we probably still
497   // want to validate and convert to be safe.
498   addConversion([](spirv::SPIRVType type) { return type; });
499 
500   addConversion([this](IndexType /*indexType*/) { return getIndexType(); });
501 
502   addConversion([this](IntegerType intType) -> Optional<Type> {
503     if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
504       return convertScalarType(this->targetEnv, this->options, scalarType);
505     return Type();
506   });
507 
508   addConversion([this](FloatType floatType) -> Optional<Type> {
509     if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
510       return convertScalarType(this->targetEnv, this->options, scalarType);
511     return Type();
512   });
513 
514   addConversion([this](VectorType vectorType) {
515     return convertVectorType(this->targetEnv, this->options, vectorType);
516   });
517 
518   addConversion([this](TensorType tensorType) {
519     return convertTensorType(this->targetEnv, this->options, tensorType);
520   });
521 
522   addConversion([this](MemRefType memRefType) {
523     return convertMemrefType(this->targetEnv, this->options, memRefType);
524   });
525 }
526 
527 //===----------------------------------------------------------------------===//
528 // func::FuncOp Conversion Patterns
529 //===----------------------------------------------------------------------===//
530 
531 namespace {
532 /// A pattern for rewriting function signature to convert arguments of functions
533 /// to be of valid SPIR-V types.
534 class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
535 public:
536   using OpConversionPattern<func::FuncOp>::OpConversionPattern;
537 
538   LogicalResult
539   matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
540                   ConversionPatternRewriter &rewriter) const override;
541 };
542 } // namespace
543 
544 LogicalResult
matchAndRewrite(func::FuncOp funcOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const545 FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
546                                   ConversionPatternRewriter &rewriter) const {
547   auto fnType = funcOp.getFunctionType();
548   if (fnType.getNumResults() > 1)
549     return failure();
550 
551   TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
552   for (const auto &argType : enumerate(fnType.getInputs())) {
553     auto convertedType = getTypeConverter()->convertType(argType.value());
554     if (!convertedType)
555       return failure();
556     signatureConverter.addInputs(argType.index(), convertedType);
557   }
558 
559   Type resultType;
560   if (fnType.getNumResults() == 1) {
561     resultType = getTypeConverter()->convertType(fnType.getResult(0));
562     if (!resultType)
563       return failure();
564   }
565 
566   // Create the converted spv.func op.
567   auto newFuncOp = rewriter.create<spirv::FuncOp>(
568       funcOp.getLoc(), funcOp.getName(),
569       rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
570                                resultType ? TypeRange(resultType)
571                                           : TypeRange()));
572 
573   // Copy over all attributes other than the function name and type.
574   for (const auto &namedAttr : funcOp->getAttrs()) {
575     if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
576         namedAttr.getName() != SymbolTable::getSymbolAttrName())
577       newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
578   }
579 
580   rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
581                               newFuncOp.end());
582   if (failed(rewriter.convertRegionTypes(
583           &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
584     return failure();
585   rewriter.eraseOp(funcOp);
586   return success();
587 }
588 
populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)589 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
590                                               RewritePatternSet &patterns) {
591   patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
592 }
593 
594 //===----------------------------------------------------------------------===//
595 // Builtin Variables
596 //===----------------------------------------------------------------------===//
597 
getBuiltinVariable(Block & body,spirv::BuiltIn builtin)598 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
599                                                   spirv::BuiltIn builtin) {
600   // Look through all global variables in the given `body` block and check if
601   // there is a spv.GlobalVariable that has the same `builtin` attribute.
602   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
603     if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
604             spirv::SPIRVDialect::getAttributeName(
605                 spirv::Decoration::BuiltIn))) {
606       auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
607       if (varBuiltIn && *varBuiltIn == builtin) {
608         return varOp;
609       }
610     }
611   }
612   return nullptr;
613 }
614 
615 /// Gets name of global variable for a builtin.
getBuiltinVarName(spirv::BuiltIn builtin)616 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
617   return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
618 }
619 
620 /// Gets or inserts a global variable for a builtin within `body` block.
621 static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block & body,Location loc,spirv::BuiltIn builtin,Type integerType,OpBuilder & builder)622 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
623                            Type integerType, OpBuilder &builder) {
624   if (auto varOp = getBuiltinVariable(body, builtin))
625     return varOp;
626 
627   OpBuilder::InsertionGuard guard(builder);
628   builder.setInsertionPointToStart(&body);
629 
630   spirv::GlobalVariableOp newVarOp;
631   switch (builtin) {
632   case spirv::BuiltIn::NumWorkgroups:
633   case spirv::BuiltIn::WorkgroupSize:
634   case spirv::BuiltIn::WorkgroupId:
635   case spirv::BuiltIn::LocalInvocationId:
636   case spirv::BuiltIn::GlobalInvocationId: {
637     auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
638                                            spirv::StorageClass::Input);
639     std::string name = getBuiltinVarName(builtin);
640     newVarOp =
641         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
642     break;
643   }
644   case spirv::BuiltIn::SubgroupId:
645   case spirv::BuiltIn::NumSubgroups:
646   case spirv::BuiltIn::SubgroupSize: {
647     auto ptrType =
648         spirv::PointerType::get(integerType, spirv::StorageClass::Input);
649     std::string name = getBuiltinVarName(builtin);
650     newVarOp =
651         builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
652     break;
653   }
654   default:
655     emitError(loc, "unimplemented builtin variable generation for ")
656         << stringifyBuiltIn(builtin);
657   }
658   return newVarOp;
659 }
660 
getBuiltinVariableValue(Operation * op,spirv::BuiltIn builtin,Type integerType,OpBuilder & builder)661 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
662                                            spirv::BuiltIn builtin,
663                                            Type integerType,
664                                            OpBuilder &builder) {
665   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
666   if (!parent) {
667     op->emitError("expected operation to be within a module-like op");
668     return nullptr;
669   }
670 
671   spirv::GlobalVariableOp varOp =
672       getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
673                                  builtin, integerType, builder);
674   Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
675   return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
676 }
677 
678 //===----------------------------------------------------------------------===//
679 // Push constant storage
680 //===----------------------------------------------------------------------===//
681 
682 /// Returns the pointer type for the push constant storage containing
683 /// `elementCount` 32-bit integer values.
getPushConstantStorageType(unsigned elementCount,Builder & builder,Type indexType)684 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
685                                                      Builder &builder,
686                                                      Type indexType) {
687   auto arrayType = spirv::ArrayType::get(indexType, elementCount,
688                                          /*stride=*/4);
689   auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
690   return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
691 }
692 
693 /// Returns the push constant varible containing `elementCount` 32-bit integer
694 /// values in `body`. Returns null op if such an op does not exit.
getPushConstantVariable(Block & body,unsigned elementCount)695 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
696                                                        unsigned elementCount) {
697   for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
698     auto ptrType = varOp.type().dyn_cast<spirv::PointerType>();
699     if (!ptrType)
700       continue;
701 
702     // Note that Vulkan requires "There must be no more than one push constant
703     // block statically used per shader entry point." So we should always reuse
704     // the existing one.
705     if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
706       auto numElements = ptrType.getPointeeType()
707                              .cast<spirv::StructType>()
708                              .getElementType(0)
709                              .cast<spirv::ArrayType>()
710                              .getNumElements();
711       if (numElements == elementCount)
712         return varOp;
713     }
714   }
715   return nullptr;
716 }
717 
718 /// Gets or inserts a global variable for push constant storage containing
719 /// `elementCount` 32-bit integer values in `block`.
720 static spirv::GlobalVariableOp
getOrInsertPushConstantVariable(Location loc,Block & block,unsigned elementCount,OpBuilder & b,Type indexType)721 getOrInsertPushConstantVariable(Location loc, Block &block,
722                                 unsigned elementCount, OpBuilder &b,
723                                 Type indexType) {
724   if (auto varOp = getPushConstantVariable(block, elementCount))
725     return varOp;
726 
727   auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
728   auto type = getPushConstantStorageType(elementCount, builder, indexType);
729   const char *name = "__push_constant_var__";
730   return builder.create<spirv::GlobalVariableOp>(loc, type, name,
731                                                  /*initializer=*/nullptr);
732 }
733 
getPushConstantValue(Operation * op,unsigned elementCount,unsigned offset,Type integerType,OpBuilder & builder)734 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
735                                   unsigned offset, Type integerType,
736                                   OpBuilder &builder) {
737   Location loc = op->getLoc();
738   Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
739   if (!parent) {
740     op->emitError("expected operation to be within a module-like op");
741     return nullptr;
742   }
743 
744   spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
745       loc, parent->getRegion(0).front(), elementCount, builder, integerType);
746 
747   Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
748   Value offsetOp = builder.create<spirv::ConstantOp>(
749       loc, integerType, builder.getI32IntegerAttr(offset));
750   auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
751   auto acOp = builder.create<spirv::AccessChainOp>(
752       loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
753   return builder.create<spirv::LoadOp>(loc, acOp);
754 }
755 
756 //===----------------------------------------------------------------------===//
757 // Index calculation
758 //===----------------------------------------------------------------------===//
759 
linearizeIndex(ValueRange indices,ArrayRef<int64_t> strides,int64_t offset,Type integerType,Location loc,OpBuilder & builder)760 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
761                                   int64_t offset, Type integerType,
762                                   Location loc, OpBuilder &builder) {
763   assert(indices.size() == strides.size() &&
764          "must provide indices for all dimensions");
765 
766   // TODO: Consider moving to use affine.apply and patterns converting
767   // affine.apply to standard ops. This needs converting to SPIR-V passes to be
768   // broken down into progressive small steps so we can have intermediate steps
769   // using other dialects. At the moment SPIR-V is the final sink.
770 
771   Value linearizedIndex = builder.create<spirv::ConstantOp>(
772       loc, integerType, IntegerAttr::get(integerType, offset));
773   for (const auto &index : llvm::enumerate(indices)) {
774     Value strideVal = builder.create<spirv::ConstantOp>(
775         loc, integerType,
776         IntegerAttr::get(integerType, strides[index.index()]));
777     Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
778     linearizedIndex =
779         builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
780   }
781   return linearizedIndex;
782 }
783 
getElementPtr(SPIRVTypeConverter & typeConverter,MemRefType baseType,Value basePtr,ValueRange indices,Location loc,OpBuilder & builder)784 spirv::AccessChainOp mlir::spirv::getElementPtr(
785     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
786     ValueRange indices, Location loc, OpBuilder &builder) {
787   // Get base and offset of the MemRefType and verify they are static.
788 
789   int64_t offset;
790   SmallVector<int64_t, 4> strides;
791   if (failed(getStridesAndOffset(baseType, strides, offset)) ||
792       llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
793       offset == MemRefType::getDynamicStrideOrOffset()) {
794     return nullptr;
795   }
796 
797   auto indexType = typeConverter.getIndexType();
798 
799   SmallVector<Value, 2> linearizedIndices;
800   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
801 
802   // Add a '0' at the start to index into the struct.
803   linearizedIndices.push_back(zero);
804 
805   if (baseType.getRank() == 0) {
806     linearizedIndices.push_back(zero);
807   } else {
808     linearizedIndices.push_back(
809         linearizeIndex(indices, strides, offset, indexType, loc, builder));
810   }
811   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
812 }
813 
814 //===----------------------------------------------------------------------===//
815 // SPIR-V ConversionTarget
816 //===----------------------------------------------------------------------===//
817 
818 std::unique_ptr<SPIRVConversionTarget>
get(spirv::TargetEnvAttr targetAttr)819 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
820   std::unique_ptr<SPIRVConversionTarget> target(
821       // std::make_unique does not work here because the constructor is private.
822       new SPIRVConversionTarget(targetAttr));
823   SPIRVConversionTarget *targetPtr = target.get();
824   target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
825       // We need to capture the raw pointer here because it is stable:
826       // target will be destroyed once this function is returned.
827       [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
828   return target;
829 }
830 
SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)831 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
832     : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
833 
isLegalOp(Operation * op)834 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
835   // Make sure this op is available at the given version. Ops not implementing
836   // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
837   // SPIR-V versions.
838   if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
839     Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
840     if (minVersion && *minVersion > this->targetEnv.getVersion()) {
841       LLVM_DEBUG(llvm::dbgs()
842                  << op->getName() << " illegal: requiring min version "
843                  << spirv::stringifyVersion(*minVersion) << "\n");
844       return false;
845     }
846   }
847   if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
848     Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
849     if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
850       LLVM_DEBUG(llvm::dbgs()
851                  << op->getName() << " illegal: requiring max version "
852                  << spirv::stringifyVersion(*maxVersion) << "\n");
853       return false;
854     }
855   }
856 
857   // Make sure this op's required extensions are allowed to use. Ops not
858   // implementing QueryExtensionInterface do not require extensions to be
859   // available.
860   if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
861     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
862                                           extensions.getExtensions())))
863       return false;
864 
865   // Make sure this op's required extensions are allowed to use. Ops not
866   // implementing QueryCapabilityInterface do not require capabilities to be
867   // available.
868   if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
869     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
870                                            capabilities.getCapabilities())))
871       return false;
872 
873   SmallVector<Type, 4> valueTypes;
874   valueTypes.append(op->operand_type_begin(), op->operand_type_end());
875   valueTypes.append(op->result_type_begin(), op->result_type_end());
876 
877   // Ensure that all types have been converted to SPIRV types.
878   if (llvm::any_of(valueTypes,
879                    [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
880     return false;
881 
882   // Special treatment for global variables, whose type requirements are
883   // conveyed by type attributes.
884   if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
885     valueTypes.push_back(globalVar.type());
886 
887   // Make sure the op's operands/results use types that are allowed by the
888   // target environment.
889   SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
890   SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
891   for (Type valueType : valueTypes) {
892     typeExtensions.clear();
893     valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
894     if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
895                                           typeExtensions)))
896       return false;
897 
898     typeCapabilities.clear();
899     valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
900     if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
901                                            typeCapabilities)))
902       return false;
903   }
904 
905   return true;
906 }
907