1 //===- SPIRVConversion.cpp - SPIR-V Conversion Utilities ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities used to lower to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Transforms/DialectConversion.h"
18 #include "llvm/ADT/Sequence.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/Support/Debug.h"
21
22 #include <functional>
23
24 #define DEBUG_TYPE "mlir-spirv-conversion"
25
26 using namespace mlir;
27
28 //===----------------------------------------------------------------------===//
29 // Utility functions
30 //===----------------------------------------------------------------------===//
31
32 /// Checks that `candidates` extension requirements are possible to be satisfied
33 /// with the given `targetEnv`.
34 ///
35 /// `candidates` is a vector of vector for extension requirements following
36 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
37 /// convention.
38 template <typename LabelT>
checkExtensionRequirements(LabelT label,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::ExtensionArrayRefVector & candidates)39 static LogicalResult checkExtensionRequirements(
40 LabelT label, const spirv::TargetEnv &targetEnv,
41 const spirv::SPIRVType::ExtensionArrayRefVector &candidates) {
42 for (const auto &ors : candidates) {
43 if (targetEnv.allows(ors))
44 continue;
45
46 LLVM_DEBUG({
47 SmallVector<StringRef> extStrings;
48 for (spirv::Extension ext : ors)
49 extStrings.push_back(spirv::stringifyExtension(ext));
50
51 llvm::dbgs() << label << " illegal: requires at least one extension in ["
52 << llvm::join(extStrings, ", ")
53 << "] but none allowed in target environment\n";
54 });
55 return failure();
56 }
57 return success();
58 }
59
60 /// Checks that `candidates`capability requirements are possible to be satisfied
61 /// with the given `isAllowedFn`.
62 ///
63 /// `candidates` is a vector of vector for capability requirements following
64 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
65 /// convention.
66 template <typename LabelT>
checkCapabilityRequirements(LabelT label,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::CapabilityArrayRefVector & candidates)67 static LogicalResult checkCapabilityRequirements(
68 LabelT label, const spirv::TargetEnv &targetEnv,
69 const spirv::SPIRVType::CapabilityArrayRefVector &candidates) {
70 for (const auto &ors : candidates) {
71 if (targetEnv.allows(ors))
72 continue;
73
74 LLVM_DEBUG({
75 SmallVector<StringRef> capStrings;
76 for (spirv::Capability cap : ors)
77 capStrings.push_back(spirv::stringifyCapability(cap));
78
79 llvm::dbgs() << label << " illegal: requires at least one capability in ["
80 << llvm::join(capStrings, ", ")
81 << "] but none allowed in target environment\n";
82 });
83 return failure();
84 }
85 return success();
86 }
87
88 /// Returns true if the given `storageClass` needs explicit layout when used in
89 /// Shader environments.
needsExplicitLayout(spirv::StorageClass storageClass)90 static bool needsExplicitLayout(spirv::StorageClass storageClass) {
91 switch (storageClass) {
92 case spirv::StorageClass::PhysicalStorageBuffer:
93 case spirv::StorageClass::PushConstant:
94 case spirv::StorageClass::StorageBuffer:
95 case spirv::StorageClass::Uniform:
96 return true;
97 default:
98 return false;
99 }
100 }
101
102 /// Wraps the given `elementType` in a struct and gets the pointer to the
103 /// struct. This is used to satisfy Vulkan interface requirements.
104 static spirv::PointerType
wrapInStructAndGetPointer(Type elementType,spirv::StorageClass storageClass)105 wrapInStructAndGetPointer(Type elementType, spirv::StorageClass storageClass) {
106 auto structType = needsExplicitLayout(storageClass)
107 ? spirv::StructType::get(elementType, /*offsetInfo=*/0)
108 : spirv::StructType::get(elementType);
109 return spirv::PointerType::get(structType, storageClass);
110 }
111
112 //===----------------------------------------------------------------------===//
113 // Type Conversion
114 //===----------------------------------------------------------------------===//
115
getIndexType() const116 Type SPIRVTypeConverter::getIndexType() const {
117 return IntegerType::get(getContext(), options.use64bitIndex ? 64 : 32);
118 }
119
/// Mapping between SPIR-V storage classes and memref memory spaces.
///
/// Each entry is `MAP_FN(storage class, memory space number)`. The list is
/// expanded into `switch` cases in both conversion directions below, so every
/// storage class and every memory space number must appear at most once.
///
/// Note: memref does not define semantics for each memory space; the meaning
/// depends on the context where it is used. There is no particular reason
/// behind the number assignments; we try to follow NVVM conventions and largely
/// give common storage classes smaller numbers. The hope is to use a symbolic
/// memory space representation eventually, once memref supports it.
// TODO: swap Generic and StorageBuffer assignment to be more akin
// to NVVM.
#define STORAGE_SPACE_MAP_LIST(MAP_FN)                                         \
  MAP_FN(spirv::StorageClass::Generic, 1)                                      \
  MAP_FN(spirv::StorageClass::StorageBuffer, 0)                                \
  MAP_FN(spirv::StorageClass::Workgroup, 3)                                    \
  MAP_FN(spirv::StorageClass::Uniform, 4)                                      \
  MAP_FN(spirv::StorageClass::Private, 5)                                      \
  MAP_FN(spirv::StorageClass::Function, 6)                                     \
  MAP_FN(spirv::StorageClass::PushConstant, 7)                                 \
  MAP_FN(spirv::StorageClass::UniformConstant, 8)                              \
  MAP_FN(spirv::StorageClass::Input, 9)                                        \
  MAP_FN(spirv::StorageClass::Output, 10)                                      \
  MAP_FN(spirv::StorageClass::CrossWorkgroup, 11)                              \
  MAP_FN(spirv::StorageClass::AtomicCounter, 12)                               \
  MAP_FN(spirv::StorageClass::Image, 13)                                       \
  MAP_FN(spirv::StorageClass::CallableDataKHR, 14)                             \
  MAP_FN(spirv::StorageClass::IncomingCallableDataKHR, 15)                     \
  MAP_FN(spirv::StorageClass::RayPayloadKHR, 16)                               \
  MAP_FN(spirv::StorageClass::HitAttributeKHR, 17)                             \
  MAP_FN(spirv::StorageClass::IncomingRayPayloadKHR, 18)                       \
  MAP_FN(spirv::StorageClass::ShaderRecordBufferKHR, 19)                       \
  MAP_FN(spirv::StorageClass::PhysicalStorageBuffer, 20)                       \
  MAP_FN(spirv::StorageClass::CodeSectionINTEL, 21)                            \
  MAP_FN(spirv::StorageClass::DeviceOnlyINTEL, 22)                             \
  MAP_FN(spirv::StorageClass::HostOnlyINTEL, 23)
153
/// Returns the memref memory space number that STORAGE_SPACE_MAP_LIST assigns
/// to `storage`. A storage class missing from that list reaches
/// llvm_unreachable.
unsigned
SPIRVTypeConverter::getMemorySpaceForStorageClass(spirv::StorageClass storage) {
  // Expands each list entry into `case <storage class>: return <space>;`.
#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
  case storage:                                                                \
    return space;

  switch (storage) { STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN) }
#undef STORAGE_SPACE_MAP_FN
  llvm_unreachable("unhandled storage class!");
}
164
/// Returns the SPIR-V storage class that STORAGE_SPACE_MAP_LIST associates with
/// memref memory space number `space`, or llvm::None if the number is not in
/// the list.
Optional<spirv::StorageClass>
SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
  // Expands each list entry into `case <space>: return <storage class>;`.
#define STORAGE_SPACE_MAP_FN(storage, space)                                   \
  case space:                                                                  \
    return storage;

  switch (space) {
    STORAGE_SPACE_MAP_LIST(STORAGE_SPACE_MAP_FN)
  default:
    return llvm::None;
  }
#undef STORAGE_SPACE_MAP_FN
}
178
getOptions() const179 const SPIRVTypeConverter::Options &SPIRVTypeConverter::getOptions() const {
180 return options;
181 }
182
getContext() const183 MLIRContext *SPIRVTypeConverter::getContext() const {
184 return targetEnv.getAttr().getContext();
185 }
186
187 #undef STORAGE_SPACE_MAP_LIST
188
// TODO: This is a utility function that should probably be exposed by the
// SPIR-V dialect. Keeping it local till the use case arises.
//
// Returns the size in bytes of `type`, or llvm::None when it cannot be
// determined statically. Handled cases:
//   - SPIR-V scalar types: bitwidth / 8; None for 1-bit (boolean) types.
//   - vectors: element count times element size.
//   - memrefs with static shape, strides and offset:
//     (offset + max_i(shape[i] * strides[i])) * element size, or just the
//     element size for rank-0 memrefs.
//   - tensors with static shape: product of the dims times element size.
// Every other type yields None. `options` is only forwarded to the recursive
// calls.
static Optional<int64_t>
getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
  if (type.isa<spirv::ScalarType>()) {
    auto bitWidth = type.getIntOrFloatBitWidth();
    // According to the SPIR-V spec:
    // "There is no physical size or bit pattern defined for values with boolean
    // type. If they are stored (in conjunction with OpVariable), they can only
    // be used with logical addressing operations, not physical, and only with
    // non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
    // Private, Function, Input, and Output."
    if (bitWidth == 1)
      return llvm::None;
    return bitWidth / 8;
  }

  if (auto vecType = type.dyn_cast<VectorType>()) {
    auto elementSize = getTypeNumBytes(options, vecType.getElementType());
    if (!elementSize)
      return llvm::None;
    return vecType.getNumElements() * *elementSize;
  }

  if (auto memRefType = type.dyn_cast<MemRefType>()) {
    // TODO: Layout should also be controlled by the ABI attributes. For now
    // using the layout from MemRef.
    int64_t offset;
    SmallVector<int64_t, 4> strides;
    if (!memRefType.hasStaticShape() ||
        failed(getStridesAndOffset(memRefType, strides, offset)))
      return llvm::None;

    // To get the size of the memref object in memory, the total size is the
    // max(stride * dimension-size) computed for all dimensions times the size
    // of the element.
    auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
    if (!elementSize)
      return llvm::None;

    // Rank-0 memrefs hold a single element; the offset is not added here.
    if (memRefType.getRank() == 0)
      return elementSize;

    // Static shape was already required above; dynamic strides or a dynamic
    // offset also make the size unknowable.
    auto dims = memRefType.getShape();
    if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
        offset == MemRefType::getDynamicStrideOrOffset() ||
        llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
      return llvm::None;

    // NOTE(review): max(dim * stride) matches the extent of non-overlapping
    // layouts, but can under-estimate it for overlapping strides (e.g.
    // shape [3, 3] with strides [1, 1]); confirm such layouts never reach here.
    int64_t memrefSize = -1;
    for (const auto &shape : enumerate(dims))
      memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);

    return (offset + memrefSize) * *elementSize;
  }

  if (auto tensorType = type.dyn_cast<TensorType>()) {
    if (!tensorType.hasStaticShape())
      return llvm::None;

    auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
    if (!elementSize)
      return llvm::None;

    // Dense tensor: element size times the product of all dimensions.
    int64_t size = *elementSize;
    for (auto shape : tensorType.getShape())
      size *= shape;

    return size;
  }

  // TODO: Add size computation for other types.
  return llvm::None;
}
263
264 /// Converts a scalar `type` to a suitable type under the given `targetEnv`.
convertScalarType(const spirv::TargetEnv & targetEnv,const SPIRVTypeConverter::Options & options,spirv::ScalarType type,Optional<spirv::StorageClass> storageClass={})265 static Type convertScalarType(const spirv::TargetEnv &targetEnv,
266 const SPIRVTypeConverter::Options &options,
267 spirv::ScalarType type,
268 Optional<spirv::StorageClass> storageClass = {}) {
269 // Get extension and capability requirements for the given type.
270 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
271 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
272 type.getExtensions(extensions, storageClass);
273 type.getCapabilities(capabilities, storageClass);
274
275 // If all requirements are met, then we can accept this type as-is.
276 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
277 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
278 return type;
279
280 // Otherwise we need to adjust the type, which really means adjusting the
281 // bitwidth given this is a scalar type.
282
283 if (!options.emulateNon32BitScalarTypes)
284 return nullptr;
285
286 if (auto floatType = type.dyn_cast<FloatType>()) {
287 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
288 return Builder(targetEnv.getContext()).getF32Type();
289 }
290
291 auto intType = type.cast<IntegerType>();
292 LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
293 return IntegerType::get(targetEnv.getContext(), /*width=*/32,
294 intType.getSignedness());
295 }
296
297 /// Converts a vector `type` to a suitable type under the given `targetEnv`.
convertVectorType(const spirv::TargetEnv & targetEnv,const SPIRVTypeConverter::Options & options,VectorType type,Optional<spirv::StorageClass> storageClass={})298 static Type convertVectorType(const spirv::TargetEnv &targetEnv,
299 const SPIRVTypeConverter::Options &options,
300 VectorType type,
301 Optional<spirv::StorageClass> storageClass = {}) {
302 if (type.getRank() == 1 && type.getNumElements() == 1)
303 return type.getElementType();
304
305 if (!spirv::CompositeType::isValid(type)) {
306 // TODO: Vector types with more than four elements can be translated into
307 // array types.
308 LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
309 return nullptr;
310 }
311
312 // Get extension and capability requirements for the given type.
313 SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
314 SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
315 type.cast<spirv::CompositeType>().getExtensions(extensions, storageClass);
316 type.cast<spirv::CompositeType>().getCapabilities(capabilities, storageClass);
317
318 // If all requirements are met, then we can accept this type as-is.
319 if (succeeded(checkCapabilityRequirements(type, targetEnv, capabilities)) &&
320 succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
321 return type;
322
323 auto elementType = convertScalarType(
324 targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
325 storageClass);
326 if (elementType)
327 return VectorType::get(type.getShape(), elementType);
328 return nullptr;
329 }
330
331 /// Converts a tensor `type` to a suitable type under the given `targetEnv`.
332 ///
333 /// Note that this is mainly for lowering constant tensors. In SPIR-V one can
334 /// create composite constants with OpConstantComposite to embed relative large
335 /// constant values and use OpCompositeExtract and OpCompositeInsert to
336 /// manipulate, like what we do for vectors.
convertTensorType(const spirv::TargetEnv & targetEnv,const SPIRVTypeConverter::Options & options,TensorType type)337 static Type convertTensorType(const spirv::TargetEnv &targetEnv,
338 const SPIRVTypeConverter::Options &options,
339 TensorType type) {
340 // TODO: Handle dynamic shapes.
341 if (!type.hasStaticShape()) {
342 LLVM_DEBUG(llvm::dbgs()
343 << type << " illegal: dynamic shape unimplemented\n");
344 return nullptr;
345 }
346
347 auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
348 if (!scalarType) {
349 LLVM_DEBUG(llvm::dbgs()
350 << type << " illegal: cannot convert non-scalar element type\n");
351 return nullptr;
352 }
353
354 Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
355 Optional<int64_t> tensorSize = getTypeNumBytes(options, type);
356 if (!scalarSize || !tensorSize) {
357 LLVM_DEBUG(llvm::dbgs()
358 << type << " illegal: cannot deduce element count\n");
359 return nullptr;
360 }
361
362 auto arrayElemCount = *tensorSize / *scalarSize;
363 auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
364 if (!arrayElemType)
365 return nullptr;
366 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
367 if (!arrayElemSize) {
368 LLVM_DEBUG(llvm::dbgs()
369 << type << " illegal: cannot deduce converted element size\n");
370 return nullptr;
371 }
372
373 return spirv::ArrayType::get(arrayElemType, arrayElemCount);
374 }
375
convertBoolMemrefType(const spirv::TargetEnv & targetEnv,const SPIRVTypeConverter::Options & options,MemRefType type)376 static Type convertBoolMemrefType(const spirv::TargetEnv &targetEnv,
377 const SPIRVTypeConverter::Options &options,
378 MemRefType type) {
379 Optional<spirv::StorageClass> storageClass =
380 SPIRVTypeConverter::getStorageClassForMemorySpace(
381 type.getMemorySpaceAsInt());
382 if (!storageClass) {
383 LLVM_DEBUG(llvm::dbgs()
384 << type << " illegal: cannot convert memory space\n");
385 return nullptr;
386 }
387
388 unsigned numBoolBits = options.boolNumBits;
389 if (numBoolBits != 8) {
390 LLVM_DEBUG(llvm::dbgs()
391 << "using non-8-bit storage for bool types unimplemented");
392 return nullptr;
393 }
394 auto elementType = IntegerType::get(type.getContext(), numBoolBits)
395 .dyn_cast<spirv::ScalarType>();
396 if (!elementType)
397 return nullptr;
398 Type arrayElemType =
399 convertScalarType(targetEnv, options, elementType, storageClass);
400 if (!arrayElemType)
401 return nullptr;
402 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
403 if (!arrayElemSize) {
404 LLVM_DEBUG(llvm::dbgs()
405 << type << " illegal: cannot deduce converted element size\n");
406 return nullptr;
407 }
408
409 if (!type.hasStaticShape()) {
410 int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
411 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
412 return wrapInStructAndGetPointer(arrayType, *storageClass);
413 }
414
415 int64_t memrefSize = (type.getNumElements() * numBoolBits + 7) / 8;
416 auto arrayElemCount = llvm::divideCeil(memrefSize, *arrayElemSize);
417 int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
418 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
419
420 return wrapInStructAndGetPointer(arrayType, *storageClass);
421 }
422
convertMemrefType(const spirv::TargetEnv & targetEnv,const SPIRVTypeConverter::Options & options,MemRefType type)423 static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
424 const SPIRVTypeConverter::Options &options,
425 MemRefType type) {
426 if (type.getElementType().isa<IntegerType>() &&
427 type.getElementTypeBitWidth() == 1) {
428 return convertBoolMemrefType(targetEnv, options, type);
429 }
430
431 Optional<spirv::StorageClass> storageClass =
432 SPIRVTypeConverter::getStorageClassForMemorySpace(
433 type.getMemorySpaceAsInt());
434 if (!storageClass) {
435 LLVM_DEBUG(llvm::dbgs()
436 << type << " illegal: cannot convert memory space\n");
437 return nullptr;
438 }
439
440 Type arrayElemType;
441 Type elementType = type.getElementType();
442 if (auto vecType = elementType.dyn_cast<VectorType>()) {
443 arrayElemType =
444 convertVectorType(targetEnv, options, vecType, storageClass);
445 } else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
446 arrayElemType =
447 convertScalarType(targetEnv, options, scalarType, storageClass);
448 } else {
449 LLVM_DEBUG(
450 llvm::dbgs()
451 << type
452 << " unhandled: can only convert scalar or vector element type\n");
453 return nullptr;
454 }
455 if (!arrayElemType)
456 return nullptr;
457
458 Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
459 if (!arrayElemSize) {
460 LLVM_DEBUG(llvm::dbgs()
461 << type << " illegal: cannot deduce converted element size\n");
462 return nullptr;
463 }
464
465 if (!type.hasStaticShape()) {
466 int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
467 auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, stride);
468 return wrapInStructAndGetPointer(arrayType, *storageClass);
469 }
470
471 Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
472 if (!memrefSize) {
473 LLVM_DEBUG(llvm::dbgs()
474 << type << " illegal: cannot deduce element count\n");
475 return nullptr;
476 }
477
478 auto arrayElemCount = llvm::divideCeil(*memrefSize, *arrayElemSize);
479 int64_t stride = needsExplicitLayout(*storageClass) ? *arrayElemSize : 0;
480 auto arrayType = spirv::ArrayType::get(arrayElemType, arrayElemCount, stride);
481
482 return wrapInStructAndGetPointer(arrayType, *storageClass);
483 }
484
/// Builds a type converter that maps builtin types to SPIR-V types that are
/// legal under `targetAttr`. `options` controls, among other things, the index
/// bitwidth and whether unsupported scalar types are emulated with 32-bit ones.
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
                                       Options options)
    : targetEnv(targetAttr), options(options) {
  // Add conversions. The order matters here: later ones will be tried earlier.
  //
  // The lambdas below capture only `this` and reach members through `this->`.
  // The constructor parameter `options` shadows the member of the same name,
  // so an unqualified `options` would refer to the uncaptured parameter.

  // Allow all SPIR-V dialect specific types. This assumes all builtin types
  // adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
  // were tried before.
  //
  // TODO: this assumes that the SPIR-V types are valid to use in
  // the given target environment, which should be the case if the whole
  // pipeline is driven by the same target environment. Still, we probably still
  // want to validate and convert to be safe.
  addConversion([](spirv::SPIRVType type) { return type; });

  // `index` maps to i32 or i64 depending on `options.use64bitIndex`.
  addConversion([this](IndexType /*indexType*/) { return getIndexType(); });

  // NOTE: for integer/float types that are not SPIR-V scalars, these callbacks
  // return a null Type rather than llvm::None. Per the TypeConverter callback
  // contract, that reports a failed conversion instead of letting other
  // callbacks try; confirm this is intended.
  addConversion([this](IntegerType intType) -> Optional<Type> {
    if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(this->targetEnv, this->options, scalarType);
    return Type();
  });

  addConversion([this](FloatType floatType) -> Optional<Type> {
    if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
      return convertScalarType(this->targetEnv, this->options, scalarType);
    return Type();
  });

  addConversion([this](VectorType vectorType) {
    return convertVectorType(this->targetEnv, this->options, vectorType);
  });

  addConversion([this](TensorType tensorType) {
    return convertTensorType(this->targetEnv, this->options, tensorType);
  });

  addConversion([this](MemRefType memRefType) {
    return convertMemrefType(this->targetEnv, this->options, memRefType);
  });
}
526
527 //===----------------------------------------------------------------------===//
528 // func::FuncOp Conversion Patterns
529 //===----------------------------------------------------------------------===//
530
531 namespace {
532 /// A pattern for rewriting function signature to convert arguments of functions
533 /// to be of valid SPIR-V types.
534 class FuncOpConversion final : public OpConversionPattern<func::FuncOp> {
535 public:
536 using OpConversionPattern<func::FuncOp>::OpConversionPattern;
537
538 LogicalResult
539 matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
540 ConversionPatternRewriter &rewriter) const override;
541 };
542 } // namespace
543
544 LogicalResult
matchAndRewrite(func::FuncOp funcOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const545 FuncOpConversion::matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
546 ConversionPatternRewriter &rewriter) const {
547 auto fnType = funcOp.getFunctionType();
548 if (fnType.getNumResults() > 1)
549 return failure();
550
551 TypeConverter::SignatureConversion signatureConverter(fnType.getNumInputs());
552 for (const auto &argType : enumerate(fnType.getInputs())) {
553 auto convertedType = getTypeConverter()->convertType(argType.value());
554 if (!convertedType)
555 return failure();
556 signatureConverter.addInputs(argType.index(), convertedType);
557 }
558
559 Type resultType;
560 if (fnType.getNumResults() == 1) {
561 resultType = getTypeConverter()->convertType(fnType.getResult(0));
562 if (!resultType)
563 return failure();
564 }
565
566 // Create the converted spv.func op.
567 auto newFuncOp = rewriter.create<spirv::FuncOp>(
568 funcOp.getLoc(), funcOp.getName(),
569 rewriter.getFunctionType(signatureConverter.getConvertedTypes(),
570 resultType ? TypeRange(resultType)
571 : TypeRange()));
572
573 // Copy over all attributes other than the function name and type.
574 for (const auto &namedAttr : funcOp->getAttrs()) {
575 if (namedAttr.getName() != FunctionOpInterface::getTypeAttrName() &&
576 namedAttr.getName() != SymbolTable::getSymbolAttrName())
577 newFuncOp->setAttr(namedAttr.getName(), namedAttr.getValue());
578 }
579
580 rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
581 newFuncOp.end());
582 if (failed(rewriter.convertRegionTypes(
583 &newFuncOp.getBody(), *getTypeConverter(), &signatureConverter)))
584 return failure();
585 rewriter.eraseOp(funcOp);
586 return success();
587 }
588
populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)589 void mlir::populateBuiltinFuncToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
590 RewritePatternSet &patterns) {
591 patterns.add<FuncOpConversion>(typeConverter, patterns.getContext());
592 }
593
594 //===----------------------------------------------------------------------===//
595 // Builtin Variables
596 //===----------------------------------------------------------------------===//
597
getBuiltinVariable(Block & body,spirv::BuiltIn builtin)598 static spirv::GlobalVariableOp getBuiltinVariable(Block &body,
599 spirv::BuiltIn builtin) {
600 // Look through all global variables in the given `body` block and check if
601 // there is a spv.GlobalVariable that has the same `builtin` attribute.
602 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
603 if (auto builtinAttr = varOp->getAttrOfType<StringAttr>(
604 spirv::SPIRVDialect::getAttributeName(
605 spirv::Decoration::BuiltIn))) {
606 auto varBuiltIn = spirv::symbolizeBuiltIn(builtinAttr.getValue());
607 if (varBuiltIn && *varBuiltIn == builtin) {
608 return varOp;
609 }
610 }
611 }
612 return nullptr;
613 }
614
615 /// Gets name of global variable for a builtin.
getBuiltinVarName(spirv::BuiltIn builtin)616 static std::string getBuiltinVarName(spirv::BuiltIn builtin) {
617 return std::string("__builtin_var_") + stringifyBuiltIn(builtin).str() + "__";
618 }
619
620 /// Gets or inserts a global variable for a builtin within `body` block.
621 static spirv::GlobalVariableOp
getOrInsertBuiltinVariable(Block & body,Location loc,spirv::BuiltIn builtin,Type integerType,OpBuilder & builder)622 getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
623 Type integerType, OpBuilder &builder) {
624 if (auto varOp = getBuiltinVariable(body, builtin))
625 return varOp;
626
627 OpBuilder::InsertionGuard guard(builder);
628 builder.setInsertionPointToStart(&body);
629
630 spirv::GlobalVariableOp newVarOp;
631 switch (builtin) {
632 case spirv::BuiltIn::NumWorkgroups:
633 case spirv::BuiltIn::WorkgroupSize:
634 case spirv::BuiltIn::WorkgroupId:
635 case spirv::BuiltIn::LocalInvocationId:
636 case spirv::BuiltIn::GlobalInvocationId: {
637 auto ptrType = spirv::PointerType::get(VectorType::get({3}, integerType),
638 spirv::StorageClass::Input);
639 std::string name = getBuiltinVarName(builtin);
640 newVarOp =
641 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
642 break;
643 }
644 case spirv::BuiltIn::SubgroupId:
645 case spirv::BuiltIn::NumSubgroups:
646 case spirv::BuiltIn::SubgroupSize: {
647 auto ptrType =
648 spirv::PointerType::get(integerType, spirv::StorageClass::Input);
649 std::string name = getBuiltinVarName(builtin);
650 newVarOp =
651 builder.create<spirv::GlobalVariableOp>(loc, ptrType, name, builtin);
652 break;
653 }
654 default:
655 emitError(loc, "unimplemented builtin variable generation for ")
656 << stringifyBuiltIn(builtin);
657 }
658 return newVarOp;
659 }
660
getBuiltinVariableValue(Operation * op,spirv::BuiltIn builtin,Type integerType,OpBuilder & builder)661 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
662 spirv::BuiltIn builtin,
663 Type integerType,
664 OpBuilder &builder) {
665 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
666 if (!parent) {
667 op->emitError("expected operation to be within a module-like op");
668 return nullptr;
669 }
670
671 spirv::GlobalVariableOp varOp =
672 getOrInsertBuiltinVariable(*parent->getRegion(0).begin(), op->getLoc(),
673 builtin, integerType, builder);
674 Value ptr = builder.create<spirv::AddressOfOp>(op->getLoc(), varOp);
675 return builder.create<spirv::LoadOp>(op->getLoc(), ptr);
676 }
677
678 //===----------------------------------------------------------------------===//
679 // Push constant storage
680 //===----------------------------------------------------------------------===//
681
682 /// Returns the pointer type for the push constant storage containing
683 /// `elementCount` 32-bit integer values.
getPushConstantStorageType(unsigned elementCount,Builder & builder,Type indexType)684 static spirv::PointerType getPushConstantStorageType(unsigned elementCount,
685 Builder &builder,
686 Type indexType) {
687 auto arrayType = spirv::ArrayType::get(indexType, elementCount,
688 /*stride=*/4);
689 auto structType = spirv::StructType::get({arrayType}, /*offsetInfo=*/0);
690 return spirv::PointerType::get(structType, spirv::StorageClass::PushConstant);
691 }
692
693 /// Returns the push constant varible containing `elementCount` 32-bit integer
694 /// values in `body`. Returns null op if such an op does not exit.
getPushConstantVariable(Block & body,unsigned elementCount)695 static spirv::GlobalVariableOp getPushConstantVariable(Block &body,
696 unsigned elementCount) {
697 for (auto varOp : body.getOps<spirv::GlobalVariableOp>()) {
698 auto ptrType = varOp.type().dyn_cast<spirv::PointerType>();
699 if (!ptrType)
700 continue;
701
702 // Note that Vulkan requires "There must be no more than one push constant
703 // block statically used per shader entry point." So we should always reuse
704 // the existing one.
705 if (ptrType.getStorageClass() == spirv::StorageClass::PushConstant) {
706 auto numElements = ptrType.getPointeeType()
707 .cast<spirv::StructType>()
708 .getElementType(0)
709 .cast<spirv::ArrayType>()
710 .getNumElements();
711 if (numElements == elementCount)
712 return varOp;
713 }
714 }
715 return nullptr;
716 }
717
718 /// Gets or inserts a global variable for push constant storage containing
719 /// `elementCount` 32-bit integer values in `block`.
720 static spirv::GlobalVariableOp
getOrInsertPushConstantVariable(Location loc,Block & block,unsigned elementCount,OpBuilder & b,Type indexType)721 getOrInsertPushConstantVariable(Location loc, Block &block,
722 unsigned elementCount, OpBuilder &b,
723 Type indexType) {
724 if (auto varOp = getPushConstantVariable(block, elementCount))
725 return varOp;
726
727 auto builder = OpBuilder::atBlockBegin(&block, b.getListener());
728 auto type = getPushConstantStorageType(elementCount, builder, indexType);
729 const char *name = "__push_constant_var__";
730 return builder.create<spirv::GlobalVariableOp>(loc, type, name,
731 /*initializer=*/nullptr);
732 }
733
getPushConstantValue(Operation * op,unsigned elementCount,unsigned offset,Type integerType,OpBuilder & builder)734 Value spirv::getPushConstantValue(Operation *op, unsigned elementCount,
735 unsigned offset, Type integerType,
736 OpBuilder &builder) {
737 Location loc = op->getLoc();
738 Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
739 if (!parent) {
740 op->emitError("expected operation to be within a module-like op");
741 return nullptr;
742 }
743
744 spirv::GlobalVariableOp varOp = getOrInsertPushConstantVariable(
745 loc, parent->getRegion(0).front(), elementCount, builder, integerType);
746
747 Value zeroOp = spirv::ConstantOp::getZero(integerType, loc, builder);
748 Value offsetOp = builder.create<spirv::ConstantOp>(
749 loc, integerType, builder.getI32IntegerAttr(offset));
750 auto addrOp = builder.create<spirv::AddressOfOp>(loc, varOp);
751 auto acOp = builder.create<spirv::AccessChainOp>(
752 loc, addrOp, llvm::makeArrayRef({zeroOp, offsetOp}));
753 return builder.create<spirv::LoadOp>(loc, acOp);
754 }
755
756 //===----------------------------------------------------------------------===//
757 // Index calculation
758 //===----------------------------------------------------------------------===//
759
linearizeIndex(ValueRange indices,ArrayRef<int64_t> strides,int64_t offset,Type integerType,Location loc,OpBuilder & builder)760 Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
761 int64_t offset, Type integerType,
762 Location loc, OpBuilder &builder) {
763 assert(indices.size() == strides.size() &&
764 "must provide indices for all dimensions");
765
766 // TODO: Consider moving to use affine.apply and patterns converting
767 // affine.apply to standard ops. This needs converting to SPIR-V passes to be
768 // broken down into progressive small steps so we can have intermediate steps
769 // using other dialects. At the moment SPIR-V is the final sink.
770
771 Value linearizedIndex = builder.create<spirv::ConstantOp>(
772 loc, integerType, IntegerAttr::get(integerType, offset));
773 for (const auto &index : llvm::enumerate(indices)) {
774 Value strideVal = builder.create<spirv::ConstantOp>(
775 loc, integerType,
776 IntegerAttr::get(integerType, strides[index.index()]));
777 Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
778 linearizedIndex =
779 builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
780 }
781 return linearizedIndex;
782 }
783
getElementPtr(SPIRVTypeConverter & typeConverter,MemRefType baseType,Value basePtr,ValueRange indices,Location loc,OpBuilder & builder)784 spirv::AccessChainOp mlir::spirv::getElementPtr(
785 SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
786 ValueRange indices, Location loc, OpBuilder &builder) {
787 // Get base and offset of the MemRefType and verify they are static.
788
789 int64_t offset;
790 SmallVector<int64_t, 4> strides;
791 if (failed(getStridesAndOffset(baseType, strides, offset)) ||
792 llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()) ||
793 offset == MemRefType::getDynamicStrideOrOffset()) {
794 return nullptr;
795 }
796
797 auto indexType = typeConverter.getIndexType();
798
799 SmallVector<Value, 2> linearizedIndices;
800 auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
801
802 // Add a '0' at the start to index into the struct.
803 linearizedIndices.push_back(zero);
804
805 if (baseType.getRank() == 0) {
806 linearizedIndices.push_back(zero);
807 } else {
808 linearizedIndices.push_back(
809 linearizeIndex(indices, strides, offset, indexType, loc, builder));
810 }
811 return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
812 }
813
814 //===----------------------------------------------------------------------===//
815 // SPIR-V ConversionTarget
816 //===----------------------------------------------------------------------===//
817
818 std::unique_ptr<SPIRVConversionTarget>
get(spirv::TargetEnvAttr targetAttr)819 SPIRVConversionTarget::get(spirv::TargetEnvAttr targetAttr) {
820 std::unique_ptr<SPIRVConversionTarget> target(
821 // std::make_unique does not work here because the constructor is private.
822 new SPIRVConversionTarget(targetAttr));
823 SPIRVConversionTarget *targetPtr = target.get();
824 target->addDynamicallyLegalDialect<spirv::SPIRVDialect>(
825 // We need to capture the raw pointer here because it is stable:
826 // target will be destroyed once this function is returned.
827 [targetPtr](Operation *op) { return targetPtr->isLegalOp(op); });
828 return target;
829 }
830
SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)831 SPIRVConversionTarget::SPIRVConversionTarget(spirv::TargetEnvAttr targetAttr)
832 : ConversionTarget(*targetAttr.getContext()), targetEnv(targetAttr) {}
833
isLegalOp(Operation * op)834 bool SPIRVConversionTarget::isLegalOp(Operation *op) {
835 // Make sure this op is available at the given version. Ops not implementing
836 // QueryMinVersionInterface/QueryMaxVersionInterface are available to all
837 // SPIR-V versions.
838 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
839 Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
840 if (minVersion && *minVersion > this->targetEnv.getVersion()) {
841 LLVM_DEBUG(llvm::dbgs()
842 << op->getName() << " illegal: requiring min version "
843 << spirv::stringifyVersion(*minVersion) << "\n");
844 return false;
845 }
846 }
847 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
848 Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
849 if (maxVersion && *maxVersion < this->targetEnv.getVersion()) {
850 LLVM_DEBUG(llvm::dbgs()
851 << op->getName() << " illegal: requiring max version "
852 << spirv::stringifyVersion(*maxVersion) << "\n");
853 return false;
854 }
855 }
856
857 // Make sure this op's required extensions are allowed to use. Ops not
858 // implementing QueryExtensionInterface do not require extensions to be
859 // available.
860 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
861 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
862 extensions.getExtensions())))
863 return false;
864
865 // Make sure this op's required extensions are allowed to use. Ops not
866 // implementing QueryCapabilityInterface do not require capabilities to be
867 // available.
868 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
869 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
870 capabilities.getCapabilities())))
871 return false;
872
873 SmallVector<Type, 4> valueTypes;
874 valueTypes.append(op->operand_type_begin(), op->operand_type_end());
875 valueTypes.append(op->result_type_begin(), op->result_type_end());
876
877 // Ensure that all types have been converted to SPIRV types.
878 if (llvm::any_of(valueTypes,
879 [](Type t) { return !t.isa<spirv::SPIRVType>(); }))
880 return false;
881
882 // Special treatment for global variables, whose type requirements are
883 // conveyed by type attributes.
884 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
885 valueTypes.push_back(globalVar.type());
886
887 // Make sure the op's operands/results use types that are allowed by the
888 // target environment.
889 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
890 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
891 for (Type valueType : valueTypes) {
892 typeExtensions.clear();
893 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
894 if (failed(checkExtensionRequirements(op->getName(), this->targetEnv,
895 typeExtensions)))
896 return false;
897
898 typeCapabilities.clear();
899 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
900 if (failed(checkCapabilityRequirements(op->getName(), this->targetEnv,
901 typeCapabilities)))
902 return false;
903 }
904
905 return true;
906 }
907