1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Serializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 18 #include "mlir/Support/LogicalResult.h" 19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/ADT/SmallPtrSet.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/ADT/bit.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "spirv-serialization" 28 29 using namespace mlir; 30 31 /// Returns the merge block if the given `op` is a structured control flow op. 32 /// Otherwise returns nullptr. 33 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { 34 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) 35 return selectionOp.getMergeBlock(); 36 if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) 37 return loopOp.getMergeBlock(); 38 return nullptr; 39 } 40 41 /// Given a predecessor `block` for a block with arguments, returns the block 42 /// that should be used as the parent block for SPIR-V OpPhi instructions 43 /// corresponding to the block arguments. 44 static Block *getPhiIncomingBlock(Block *block) { 45 // If the predecessor block in question is the entry block for a 46 // spv.mlir.loop, we jump to this spv.mlir.loop from its enclosing block. 47 if (block->isEntryBlock()) { 48 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { 49 // Then the incoming parent block for OpPhi should be the merge block of 50 // the structured control flow op before this loop. 51 Operation *op = loopOp.getOperation(); 52 while ((op = op->getPrevNode()) != nullptr) 53 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) 54 return incomingBlock; 55 // Or the enclosing block itself if no structured control flow ops 56 // exists before this loop. 57 return loopOp->getBlock(); 58 } 59 } 60 61 // Otherwise, we jump from the given predecessor block. Try to see if there is 62 // a structured control flow op inside it. 63 for (Operation &op : llvm::reverse(block->getOperations())) { 64 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) 65 return incomingBlock; 66 } 67 return block; 68 } 69 70 namespace mlir { 71 namespace spirv { 72 73 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into 74 /// the given `binary` vector. 75 LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, 76 spirv::Opcode op, 77 ArrayRef<uint32_t> operands) { 78 uint32_t wordCount = 1 + operands.size(); 79 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 80 binary.append(operands.begin(), operands.end()); 81 return success(); 82 } 83 84 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) 85 : module(module), mlirBuilder(module.getContext()), 86 emitDebugInfo(emitDebugInfo) {} 87 88 LogicalResult Serializer::serialize() { 89 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); 90 91 if (failed(module.verify())) 92 return failure(); 93 94 // TODO: handle the other sections 95 processCapability(); 96 processExtension(); 97 processMemoryModel(); 98 processDebugInfo(); 99 100 // Iterate over the module body to serialize it. Assumptions are that there is 101 // only one basic block in the moduleOp 102 for (auto &op : *module.getBody()) { 103 if (failed(processOperation(&op))) { 104 return failure(); 105 } 106 } 107 108 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); 109 return success(); 110 } 111 112 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { 113 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + 114 extensions.size() + extendedSets.size() + 115 memoryModel.size() + entryPoints.size() + 116 executionModes.size() + decorations.size() + 117 typesGlobalValues.size() + functions.size(); 118 119 binary.clear(); 120 binary.reserve(moduleSize); 121 122 spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); 123 binary.append(capabilities.begin(), capabilities.end()); 124 binary.append(extensions.begin(), extensions.end()); 125 binary.append(extendedSets.begin(), extendedSets.end()); 126 binary.append(memoryModel.begin(), memoryModel.end()); 127 binary.append(entryPoints.begin(), entryPoints.end()); 128 binary.append(executionModes.begin(), executionModes.end()); 129 binary.append(debug.begin(), debug.end()); 130 binary.append(names.begin(), names.end()); 131 binary.append(decorations.begin(), decorations.end()); 132 binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); 133 binary.append(functions.begin(), functions.end()); 134 } 135 136 #ifndef NDEBUG 137 void Serializer::printValueIDMap(raw_ostream &os) { 138 os << "\n= Value <id> Map =\n\n"; 139 for (auto valueIDPair : valueIDMap) { 140 Value val = valueIDPair.first; 141 os << " " << val << " " 142 << "id = " << valueIDPair.second << ' '; 143 if (auto *op = val.getDefiningOp()) { 144 os << "from op '" << op->getName() << "'"; 145 } else if (auto arg = val.dyn_cast<BlockArgument>()) { 146 Block *block = arg.getOwner(); 147 os << "from argument of block " << block << ' '; 148 os << " in op '" << block->getParentOp()->getName() << "'"; 149 } 150 os << '\n'; 151 } 152 } 153 #endif 154 155 //===----------------------------------------------------------------------===// 156 // Module structure 157 //===----------------------------------------------------------------------===// 158 159 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { 160 auto funcID = funcIDMap.lookup(fnName); 161 if (!funcID) { 162 funcID = getNextID(); 163 funcIDMap[fnName] = funcID; 164 } 165 return funcID; 166 } 167 168 void Serializer::processCapability() { 169 for (auto cap : module.vce_triple()->getCapabilities()) 170 (void)encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, 171 {static_cast<uint32_t>(cap)}); 172 } 173 174 void Serializer::processDebugInfo() { 175 if (!emitDebugInfo) 176 return; 177 auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>(); 178 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>"; 179 fileID = getNextID(); 180 SmallVector<uint32_t, 16> operands; 181 operands.push_back(fileID); 182 (void)spirv::encodeStringLiteralInto(operands, fileName); 183 (void)encodeInstructionInto(debug, spirv::Opcode::OpString, operands); 184 // TODO: Encode more debug instructions. 185 } 186 187 void Serializer::processExtension() { 188 llvm::SmallVector<uint32_t, 16> extName; 189 for (spirv::Extension ext : module.vce_triple()->getExtensions()) { 190 extName.clear(); 191 (void)spirv::encodeStringLiteralInto(extName, 192 spirv::stringifyExtension(ext)); 193 (void)encodeInstructionInto(extensions, spirv::Opcode::OpExtension, 194 extName); 195 } 196 } 197 198 void Serializer::processMemoryModel() { 199 uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt(); 200 uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt(); 201 202 (void)encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, 203 {am, mm}); 204 } 205 206 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, 207 NamedAttribute attr) { 208 auto attrName = attr.first.strref(); 209 auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); 210 auto decoration = spirv::symbolizeDecoration(decorationName); 211 if (!decoration) { 212 return emitError( 213 loc, "non-argument attributes expected to have snake-case-ified " 214 "decoration name, unhandled attribute with name : ") 215 << attrName; 216 } 217 SmallVector<uint32_t, 1> args; 218 switch (decoration.getValue()) { 219 case spirv::Decoration::Binding: 220 case spirv::Decoration::DescriptorSet: 221 case spirv::Decoration::Location: 222 if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) { 223 args.push_back(intAttr.getValue().getZExtValue()); 224 break; 225 } 226 return emitError(loc, "expected integer attribute for ") << attrName; 227 case spirv::Decoration::BuiltIn: 228 if (auto strAttr = attr.second.dyn_cast<StringAttr>()) { 229 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); 230 if (enumVal) { 231 args.push_back(static_cast<uint32_t>(enumVal.getValue())); 232 break; 233 } 234 return emitError(loc, "invalid ") 235 << attrName << " attribute " << strAttr.getValue(); 236 } 237 return emitError(loc, "expected string attribute for ") << attrName; 238 case spirv::Decoration::Aliased: 239 case spirv::Decoration::Flat: 240 case spirv::Decoration::NonReadable: 241 case spirv::Decoration::NonWritable: 242 case spirv::Decoration::NoPerspective: 243 case spirv::Decoration::Restrict: 244 case spirv::Decoration::RelaxedPrecision: 245 // For unit attributes, the args list has no values so we do nothing 246 if (auto unitAttr = attr.second.dyn_cast<UnitAttr>()) 247 break; 248 return emitError(loc, "expected unit attribute for ") << attrName; 249 default: 250 return emitError(loc, "unhandled decoration ") << decorationName; 251 } 252 return emitDecoration(resultID, decoration.getValue(), args); 253 } 254 255 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { 256 assert(!name.empty() && "unexpected empty string for OpName"); 257 258 SmallVector<uint32_t, 4> nameOperands; 259 nameOperands.push_back(resultID); 260 if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { 261 return failure(); 262 } 263 return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); 264 } 265 266 template <> 267 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( 268 Location loc, spirv::ArrayType type, uint32_t resultID) { 269 if (unsigned stride = type.getArrayStride()) { 270 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 271 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 272 } 273 return success(); 274 } 275 276 template <> 277 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( 278 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) { 279 if (unsigned stride = type.getArrayStride()) { 280 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 281 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 282 } 283 return success(); 284 } 285 286 LogicalResult Serializer::processMemberDecoration( 287 uint32_t structID, 288 const spirv::StructType::MemberDecorationInfo &memberDecoration) { 289 SmallVector<uint32_t, 4> args( 290 {structID, memberDecoration.memberIndex, 291 static_cast<uint32_t>(memberDecoration.decoration)}); 292 if (memberDecoration.hasValue) { 293 args.push_back(memberDecoration.decorationValue); 294 } 295 return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, 296 args); 297 } 298 299 //===----------------------------------------------------------------------===// 300 // Type 301 //===----------------------------------------------------------------------===// 302 303 // According to the SPIR-V spec "Validation Rules for Shader Capabilities": 304 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and 305 // PushConstant Storage Classes must be explicitly laid out." 306 bool Serializer::isInterfaceStructPtrType(Type type) const { 307 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 308 switch (ptrType.getStorageClass()) { 309 case spirv::StorageClass::PhysicalStorageBuffer: 310 case spirv::StorageClass::PushConstant: 311 case spirv::StorageClass::StorageBuffer: 312 case spirv::StorageClass::Uniform: 313 return ptrType.getPointeeType().isa<spirv::StructType>(); 314 default: 315 break; 316 } 317 } 318 return false; 319 } 320 321 LogicalResult Serializer::processType(Location loc, Type type, 322 uint32_t &typeID) { 323 // Maintains a set of names for nested identified struct types. This is used 324 // to properly serialize recursive references. 325 SetVector<StringRef> serializationCtx; 326 return processTypeImpl(loc, type, typeID, serializationCtx); 327 } 328 329 LogicalResult 330 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, 331 SetVector<StringRef> &serializationCtx) { 332 typeID = getTypeID(type); 333 if (typeID) { 334 return success(); 335 } 336 typeID = getNextID(); 337 SmallVector<uint32_t, 4> operands; 338 339 operands.push_back(typeID); 340 auto typeEnum = spirv::Opcode::OpTypeVoid; 341 bool deferSerialization = false; 342 343 if ((type.isa<FunctionType>() && 344 succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum, 345 operands))) || 346 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, 347 deferSerialization, serializationCtx))) { 348 if (deferSerialization) 349 return success(); 350 351 typeIDMap[type] = typeID; 352 353 if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) 354 return failure(); 355 356 if (recursiveStructInfos.count(type) != 0) { 357 // This recursive struct type is emitted already, now the OpTypePointer 358 // instructions referring to recursive references are emitted as well. 359 for (auto &ptrInfo : recursiveStructInfos[type]) { 360 // TODO: This might not work if more than 1 recursive reference is 361 // present in the struct. 362 SmallVector<uint32_t, 4> ptrOperands; 363 ptrOperands.push_back(ptrInfo.pointerTypeID); 364 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass)); 365 ptrOperands.push_back(typeIDMap[type]); 366 367 if (failed(encodeInstructionInto( 368 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) 369 return failure(); 370 } 371 372 recursiveStructInfos[type].clear(); 373 } 374 375 return success(); 376 } 377 378 return failure(); 379 } 380 381 LogicalResult Serializer::prepareBasicType( 382 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, 383 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, 384 SetVector<StringRef> &serializationCtx) { 385 deferSerialization = false; 386 387 if (isVoidType(type)) { 388 typeEnum = spirv::Opcode::OpTypeVoid; 389 return success(); 390 } 391 392 if (auto intType = type.dyn_cast<IntegerType>()) { 393 if (intType.getWidth() == 1) { 394 typeEnum = spirv::Opcode::OpTypeBool; 395 return success(); 396 } 397 398 typeEnum = spirv::Opcode::OpTypeInt; 399 operands.push_back(intType.getWidth()); 400 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 401 // to preserve or validate. 402 // 0 indicates unsigned, or no signedness semantics 403 // 1 indicates signed semantics." 404 operands.push_back(intType.isSigned() ? 1 : 0); 405 return success(); 406 } 407 408 if (auto floatType = type.dyn_cast<FloatType>()) { 409 typeEnum = spirv::Opcode::OpTypeFloat; 410 operands.push_back(floatType.getWidth()); 411 return success(); 412 } 413 414 if (auto vectorType = type.dyn_cast<VectorType>()) { 415 uint32_t elementTypeID = 0; 416 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, 417 serializationCtx))) { 418 return failure(); 419 } 420 typeEnum = spirv::Opcode::OpTypeVector; 421 operands.push_back(elementTypeID); 422 operands.push_back(vectorType.getNumElements()); 423 return success(); 424 } 425 426 if (auto imageType = type.dyn_cast<spirv::ImageType>()) { 427 typeEnum = spirv::Opcode::OpTypeImage; 428 uint32_t sampledTypeID = 0; 429 if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) 430 return failure(); 431 432 operands.push_back(sampledTypeID); 433 operands.push_back(static_cast<uint32_t>(imageType.getDim())); 434 operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo())); 435 operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo())); 436 operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo())); 437 operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo())); 438 operands.push_back(static_cast<uint32_t>(imageType.getImageFormat())); 439 return success(); 440 } 441 442 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) { 443 typeEnum = spirv::Opcode::OpTypeArray; 444 uint32_t elementTypeID = 0; 445 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, 446 serializationCtx))) { 447 return failure(); 448 } 449 operands.push_back(elementTypeID); 450 if (auto elementCountID = prepareConstantInt( 451 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { 452 operands.push_back(elementCountID); 453 } 454 return processTypeDecoration(loc, arrayType, resultID); 455 } 456 457 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 458 uint32_t pointeeTypeID = 0; 459 spirv::StructType pointeeStruct = 460 ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 461 462 if (pointeeStruct && pointeeStruct.isIdentified() && 463 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { 464 // A recursive reference to an enclosing struct is found. 465 // 466 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage 467 // class as operands. 468 SmallVector<uint32_t, 2> forwardPtrOperands; 469 forwardPtrOperands.push_back(resultID); 470 forwardPtrOperands.push_back( 471 static_cast<uint32_t>(ptrType.getStorageClass())); 472 473 (void)encodeInstructionInto(typesGlobalValues, 474 spirv::Opcode::OpTypeForwardPointer, 475 forwardPtrOperands); 476 477 // 2. Find the pointee (enclosing) struct. 478 auto structType = spirv::StructType::getIdentified( 479 module.getContext(), pointeeStruct.getIdentifier()); 480 481 if (!structType) 482 return failure(); 483 484 // 3. Mark the OpTypePointer that is supposed to be emitted by this call 485 // as deferred. 486 deferSerialization = true; 487 488 // 4. Record the info needed to emit the deferred OpTypePointer 489 // instruction when the enclosing struct is completely serialized. 490 recursiveStructInfos[structType].push_back( 491 {resultID, ptrType.getStorageClass()}); 492 } else { 493 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, 494 serializationCtx))) 495 return failure(); 496 } 497 498 typeEnum = spirv::Opcode::OpTypePointer; 499 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); 500 operands.push_back(pointeeTypeID); 501 return success(); 502 } 503 504 if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) { 505 uint32_t elementTypeID = 0; 506 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), 507 elementTypeID, serializationCtx))) { 508 return failure(); 509 } 510 typeEnum = spirv::Opcode::OpTypeRuntimeArray; 511 operands.push_back(elementTypeID); 512 return processTypeDecoration(loc, runtimeArrayType, resultID); 513 } 514 515 if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) { 516 typeEnum = spirv::Opcode::OpTypeSampledImage; 517 uint32_t imageTypeID = 0; 518 if (failed( 519 processType(loc, sampledImageType.getImageType(), imageTypeID))) { 520 return failure(); 521 } 522 operands.push_back(imageTypeID); 523 return success(); 524 } 525 526 if (auto structType = type.dyn_cast<spirv::StructType>()) { 527 if (structType.isIdentified()) { 528 (void)processName(resultID, structType.getIdentifier()); 529 serializationCtx.insert(structType.getIdentifier()); 530 } 531 532 bool hasOffset = structType.hasOffset(); 533 for (auto elementIndex : 534 llvm::seq<uint32_t>(0, structType.getNumElements())) { 535 uint32_t elementTypeID = 0; 536 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), 537 elementTypeID, serializationCtx))) { 538 return failure(); 539 } 540 operands.push_back(elementTypeID); 541 if (hasOffset) { 542 // Decorate each struct member with an offset 543 spirv::StructType::MemberDecorationInfo offsetDecoration{ 544 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, 545 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; 546 if (failed(processMemberDecoration(resultID, offsetDecoration))) { 547 return emitError(loc, "cannot decorate ") 548 << elementIndex << "-th member of " << structType 549 << " with its offset"; 550 } 551 } 552 } 553 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 554 structType.getMemberDecorations(memberDecorations); 555 556 for (auto &memberDecoration : memberDecorations) { 557 if (failed(processMemberDecoration(resultID, memberDecoration))) { 558 return emitError(loc, "cannot decorate ") 559 << static_cast<uint32_t>(memberDecoration.memberIndex) 560 << "-th member of " << structType << " with " 561 << stringifyDecoration(memberDecoration.decoration); 562 } 563 } 564 565 typeEnum = spirv::Opcode::OpTypeStruct; 566 567 if (structType.isIdentified()) 568 serializationCtx.remove(structType.getIdentifier()); 569 570 return success(); 571 } 572 573 if (auto cooperativeMatrixType = 574 type.dyn_cast<spirv::CooperativeMatrixNVType>()) { 575 uint32_t elementTypeID = 0; 576 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), 577 elementTypeID, serializationCtx))) { 578 return failure(); 579 } 580 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; 581 auto getConstantOp = [&](uint32_t id) { 582 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); 583 return prepareConstantInt(loc, attr); 584 }; 585 operands.push_back(elementTypeID); 586 operands.push_back( 587 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()))); 588 operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); 589 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); 590 return success(); 591 } 592 593 if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) { 594 uint32_t elementTypeID = 0; 595 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, 596 serializationCtx))) { 597 return failure(); 598 } 599 typeEnum = spirv::Opcode::OpTypeMatrix; 600 operands.push_back(elementTypeID); 601 operands.push_back(matrixType.getNumColumns()); 602 return success(); 603 } 604 605 // TODO: Handle other types. 606 return emitError(loc, "unhandled type in serialization: ") << type; 607 } 608 609 LogicalResult 610 Serializer::prepareFunctionType(Location loc, FunctionType type, 611 spirv::Opcode &typeEnum, 612 SmallVectorImpl<uint32_t> &operands) { 613 typeEnum = spirv::Opcode::OpTypeFunction; 614 assert(type.getNumResults() <= 1 && 615 "serialization supports only a single return value"); 616 uint32_t resultID = 0; 617 if (failed(processType( 618 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), 619 resultID))) { 620 return failure(); 621 } 622 operands.push_back(resultID); 623 for (auto &res : type.getInputs()) { 624 uint32_t argTypeID = 0; 625 if (failed(processType(loc, res, argTypeID))) { 626 return failure(); 627 } 628 operands.push_back(argTypeID); 629 } 630 return success(); 631 } 632 633 //===----------------------------------------------------------------------===// 634 // Constant 635 //===----------------------------------------------------------------------===// 636 637 uint32_t Serializer::prepareConstant(Location loc, Type constType, 638 Attribute valueAttr) { 639 if (auto id = prepareConstantScalar(loc, valueAttr)) { 640 return id; 641 } 642 643 // This is a composite literal. We need to handle each component separately 644 // and then emit an OpConstantComposite for the whole. 645 646 if (auto id = getConstantID(valueAttr)) { 647 return id; 648 } 649 650 uint32_t typeID = 0; 651 if (failed(processType(loc, constType, typeID))) { 652 return 0; 653 } 654 655 uint32_t resultID = 0; 656 if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) { 657 int rank = attr.getType().dyn_cast<ShapedType>().getRank(); 658 SmallVector<uint64_t, 4> index(rank); 659 resultID = prepareDenseElementsConstant(loc, constType, attr, 660 /*dim=*/0, index); 661 } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) { 662 resultID = prepareArrayConstant(loc, constType, arrayAttr); 663 } 664 665 if (resultID == 0) { 666 emitError(loc, "cannot serialize attribute: ") << valueAttr; 667 return 0; 668 } 669 670 constIDMap[valueAttr] = resultID; 671 return resultID; 672 } 673 674 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, 675 ArrayAttr attr) { 676 uint32_t typeID = 0; 677 if (failed(processType(loc, constType, typeID))) { 678 return 0; 679 } 680 681 uint32_t resultID = getNextID(); 682 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 683 operands.reserve(attr.size() + 2); 684 auto elementType = constType.cast<spirv::ArrayType>().getElementType(); 685 for (Attribute elementAttr : attr) { 686 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { 687 operands.push_back(elementID); 688 } else { 689 return 0; 690 } 691 } 692 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 693 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 694 695 return resultID; 696 } 697 698 // TODO: Turn the below function into iterative function, instead of 699 // recursive function. 700 uint32_t 701 Serializer::prepareDenseElementsConstant(Location loc, Type constType, 702 DenseElementsAttr valueAttr, int dim, 703 MutableArrayRef<uint64_t> index) { 704 auto shapedType = valueAttr.getType().dyn_cast<ShapedType>(); 705 assert(dim <= shapedType.getRank()); 706 if (shapedType.getRank() == dim) { 707 if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) { 708 return attr.getType().getElementType().isInteger(1) 709 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index]) 710 : prepareConstantInt(loc, 711 attr.getValues<IntegerAttr>()[index]); 712 } 713 if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) { 714 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]); 715 } 716 return 0; 717 } 718 719 uint32_t typeID = 0; 720 if (failed(processType(loc, constType, typeID))) { 721 return 0; 722 } 723 724 uint32_t resultID = getNextID(); 725 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 726 operands.reserve(shapedType.getDimSize(dim) + 2); 727 auto elementType = constType.cast<spirv::CompositeType>().getElementType(0); 728 for (int i = 0; i < shapedType.getDimSize(dim); ++i) { 729 index[dim] = i; 730 if (auto elementID = prepareDenseElementsConstant( 731 loc, elementType, valueAttr, dim + 1, index)) { 732 operands.push_back(elementID); 733 } else { 734 return 0; 735 } 736 } 737 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 738 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 739 740 return resultID; 741 } 742 743 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, 744 bool isSpec) { 745 if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) { 746 return prepareConstantFp(loc, floatAttr, isSpec); 747 } 748 if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) { 749 return prepareConstantBool(loc, boolAttr, isSpec); 750 } 751 if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) { 752 return prepareConstantInt(loc, intAttr, isSpec); 753 } 754 755 return 0; 756 } 757 758 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, 759 bool isSpec) { 760 if (!isSpec) { 761 // We can de-duplicate normal constants, but not specialization constants. 762 if (auto id = getConstantID(boolAttr)) { 763 return id; 764 } 765 } 766 767 // Process the type for this bool literal 768 uint32_t typeID = 0; 769 if (failed(processType(loc, boolAttr.getType(), typeID))) { 770 return 0; 771 } 772 773 auto resultID = getNextID(); 774 auto opcode = boolAttr.getValue() 775 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue 776 : spirv::Opcode::OpConstantTrue) 777 : (isSpec ? spirv::Opcode::OpSpecConstantFalse 778 : spirv::Opcode::OpConstantFalse); 779 (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); 780 781 if (!isSpec) { 782 constIDMap[boolAttr] = resultID; 783 } 784 return resultID; 785 } 786 787 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, 788 bool isSpec) { 789 if (!isSpec) { 790 // We can de-duplicate normal constants, but not specialization constants. 791 if (auto id = getConstantID(intAttr)) { 792 return id; 793 } 794 } 795 796 // Process the type for this integer literal 797 uint32_t typeID = 0; 798 if (failed(processType(loc, intAttr.getType(), typeID))) { 799 return 0; 800 } 801 802 auto resultID = getNextID(); 803 APInt value = intAttr.getValue(); 804 unsigned bitwidth = value.getBitWidth(); 805 bool isSigned = value.isSignedIntN(bitwidth); 806 807 auto opcode = 808 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 809 810 switch (bitwidth) { 811 // According to SPIR-V spec, "When the type's bit width is less than 812 // 32-bits, the literal's value appears in the low-order bits of the word, 813 // and the high-order bits must be 0 for a floating-point type, or 0 for an 814 // integer type with Signedness of 0, or sign extended when Signedness 815 // is 1." 816 case 32: 817 case 16: 818 case 8: { 819 uint32_t word = 0; 820 if (isSigned) { 821 word = static_cast<int32_t>(value.getSExtValue()); 822 } else { 823 word = static_cast<uint32_t>(value.getZExtValue()); 824 } 825 (void)encodeInstructionInto(typesGlobalValues, opcode, 826 {typeID, resultID, word}); 827 } break; 828 // According to SPIR-V spec: "When the type's bit width is larger than one 829 // word, the literal’s low-order words appear first." 830 case 64: { 831 struct DoubleWord { 832 uint32_t word1; 833 uint32_t word2; 834 } words; 835 if (isSigned) { 836 words = llvm::bit_cast<DoubleWord>(value.getSExtValue()); 837 } else { 838 words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); 839 } 840 (void)encodeInstructionInto(typesGlobalValues, opcode, 841 {typeID, resultID, words.word1, words.word2}); 842 } break; 843 default: { 844 std::string valueStr; 845 llvm::raw_string_ostream rss(valueStr); 846 value.print(rss, /*isSigned=*/false); 847 848 emitError(loc, "cannot serialize ") 849 << bitwidth << "-bit integer literal: " << rss.str(); 850 return 0; 851 } 852 } 853 854 if (!isSpec) { 855 constIDMap[intAttr] = resultID; 856 } 857 return resultID; 858 } 859 860 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, 861 bool isSpec) { 862 if (!isSpec) { 863 // We can de-duplicate normal constants, but not specialization constants. 864 if (auto id = getConstantID(floatAttr)) { 865 return id; 866 } 867 } 868 869 // Process the type for this float literal 870 uint32_t typeID = 0; 871 if (failed(processType(loc, floatAttr.getType(), typeID))) { 872 return 0; 873 } 874 875 auto resultID = getNextID(); 876 APFloat value = floatAttr.getValue(); 877 APInt intValue = value.bitcastToAPInt(); 878 879 auto opcode = 880 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 881 882 if (&value.getSemantics() == &APFloat::IEEEsingle()) { 883 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); 884 (void)encodeInstructionInto(typesGlobalValues, opcode, 885 {typeID, resultID, word}); 886 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { 887 struct DoubleWord { 888 uint32_t word1; 889 uint32_t word2; 890 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); 891 (void)encodeInstructionInto(typesGlobalValues, opcode, 892 {typeID, resultID, words.word1, words.word2}); 893 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { 894 uint32_t word = 895 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); 896 (void)encodeInstructionInto(typesGlobalValues, opcode, 897 {typeID, resultID, word}); 898 } else { 899 std::string valueStr; 900 llvm::raw_string_ostream rss(valueStr); 901 value.print(rss); 902 903 emitError(loc, "cannot serialize ") 904 << floatAttr.getType() << "-typed float literal: " << rss.str(); 905 return 0; 906 } 907 908 if (!isSpec) { 909 constIDMap[floatAttr] = resultID; 910 } 911 return resultID; 912 } 913 914 //===----------------------------------------------------------------------===// 915 // Control flow 916 //===----------------------------------------------------------------------===// 917 918 uint32_t Serializer::getOrCreateBlockID(Block *block) { 919 if (uint32_t id = getBlockID(block)) 920 return id; 921 return blockIDMap[block] = getNextID(); 922 } 923 924 LogicalResult 925 Serializer::processBlock(Block *block, bool omitLabel, 926 function_ref<void()> actionBeforeTerminator) { 927 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); 928 LLVM_DEBUG(block->print(llvm::dbgs())); 929 LLVM_DEBUG(llvm::dbgs() << '\n'); 930 if (!omitLabel) { 931 uint32_t blockID = getOrCreateBlockID(block); 932 LLVM_DEBUG(llvm::dbgs() 933 << "[block] " << block << " (id = " << blockID << ")\n"); 934 935 // Emit OpLabel for this block. 936 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, 937 {blockID}); 938 } 939 940 // Emit OpPhi instructions for block arguments, if any. 941 if (failed(emitPhiForBlockArguments(block))) 942 return failure(); 943 944 // Process each op in this block except the terminator. 945 for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { 946 if (failed(processOperation(&op))) 947 return failure(); 948 } 949 950 // Process the terminator. 951 if (actionBeforeTerminator) 952 actionBeforeTerminator(); 953 if (failed(processOperation(&block->back()))) 954 return failure(); 955 956 return success(); 957 } 958 959 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { 960 // Nothing to do if this block has no arguments or it's the entry block, which 961 // always has the same arguments as the function signature. 962 if (block->args_empty() || block->isEntryBlock()) 963 return success(); 964 965 // If the block has arguments, we need to create SPIR-V OpPhi instructions. 966 // A SPIR-V OpPhi instruction is of the syntax: 967 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair 968 // So we need to collect all predecessor blocks and the arguments they send 969 // to this block. 970 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors; 971 for (Block *predecessor : block->getPredecessors()) { 972 auto *terminator = predecessor->getTerminator(); 973 // The predecessor here is the immediate one according to MLIR's IR 974 // structure. It does not directly map to the incoming parent block for the 975 // OpPhi instructions at SPIR-V binary level. This is because structured 976 // control flow ops are serialized to multiple SPIR-V blocks. If there is a 977 // spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the 978 // branch op jumping to the OpPhi's block then resides in the previous 979 // structured control flow op's merge block. 980 predecessor = getPhiIncomingBlock(predecessor); 981 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { 982 predecessors.emplace_back(predecessor, branchOp.getOperands()); 983 } else if (auto branchCondOp = 984 dyn_cast<spirv::BranchConditionalOp>(terminator)) { 985 Optional<OperandRange> blockOperands; 986 987 for (auto successorIdx : 988 llvm::seq<unsigned>(0, predecessor->getNumSuccessors())) 989 if (predecessor->getSuccessors()[successorIdx] == block) { 990 blockOperands = branchCondOp.getSuccessorOperands(successorIdx); 991 break; 992 } 993 994 assert(blockOperands && !blockOperands->empty() && 995 "expected non-empty block operand range"); 996 predecessors.emplace_back(predecessor, *blockOperands); 997 } else { 998 return terminator->emitError("unimplemented terminator for Phi creation"); 999 } 1000 } 1001 1002 // Then create OpPhi instruction for each of the block argument. 1003 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) { 1004 BlockArgument arg = block->getArgument(argIndex); 1005 1006 // Get the type <id> and result <id> for this OpPhi instruction. 1007 uint32_t phiTypeID = 0; 1008 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) 1009 return failure(); 1010 uint32_t phiID = getNextID(); 1011 1012 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' 1013 << arg << " (id = " << phiID << ")\n"); 1014 1015 // Prepare the (value <id>, parent block <id>) pairs. 1016 SmallVector<uint32_t, 8> phiArgs; 1017 phiArgs.push_back(phiTypeID); 1018 phiArgs.push_back(phiID); 1019 1020 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) { 1021 Value value = predecessors[predIndex].second[argIndex]; 1022 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); 1023 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId 1024 << ") value " << value << ' '); 1025 // Each pair is a value <id> ... 1026 uint32_t valueId = getValueID(value); 1027 if (valueId == 0) { 1028 // The op generating this value hasn't been visited yet so we don't have 1029 // an <id> assigned yet. Record this to fix up later. 1030 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); 1031 deferredPhiValues[value].push_back(functionBody.size() + 1 + 1032 phiArgs.size()); 1033 } else { 1034 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); 1035 } 1036 phiArgs.push_back(valueId); 1037 // ... and a parent block <id>. 1038 phiArgs.push_back(predBlockId); 1039 } 1040 1041 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); 1042 valueIDMap[arg] = phiID; 1043 } 1044 1045 return success(); 1046 } 1047 1048 //===----------------------------------------------------------------------===// 1049 // Operation 1050 //===----------------------------------------------------------------------===// 1051 1052 LogicalResult Serializer::encodeExtensionInstruction( 1053 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, 1054 ArrayRef<uint32_t> operands) { 1055 // Check if the extension has been imported. 1056 auto &setID = extendedInstSetIDMap[extensionSetName]; 1057 if (!setID) { 1058 setID = getNextID(); 1059 SmallVector<uint32_t, 16> importOperands; 1060 importOperands.push_back(setID); 1061 if (failed( 1062 spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || 1063 failed(encodeInstructionInto( 1064 extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { 1065 return failure(); 1066 } 1067 } 1068 1069 // The first two operands are the result type <id> and result <id>. The set 1070 // <id> and the opcode need to be insert after this. 1071 if (operands.size() < 2) { 1072 return op->emitError("extended instructions must have a result encoding"); 1073 } 1074 SmallVector<uint32_t, 8> extInstOperands; 1075 extInstOperands.reserve(operands.size() + 2); 1076 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); 1077 extInstOperands.push_back(setID); 1078 extInstOperands.push_back(extensionOpcode); 1079 extInstOperands.append(std::next(operands.begin(), 2), operands.end()); 1080 return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, 1081 extInstOperands); 1082 } 1083 1084 LogicalResult Serializer::processOperation(Operation *opInst) { 1085 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); 1086 1087 // First dispatch the ops that do not directly mirror an instruction from 1088 // the SPIR-V spec. 1089 return TypeSwitch<Operation *, LogicalResult>(opInst) 1090 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) 1091 .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) 1092 .Case([&](spirv::BranchConditionalOp op) { 1093 return processBranchConditionalOp(op); 1094 }) 1095 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) 1096 .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) 1097 .Case([&](spirv::GlobalVariableOp op) { 1098 return processGlobalVariableOp(op); 1099 }) 1100 .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) 1101 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) 1102 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) 1103 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) 1104 .Case([&](spirv::SpecConstantCompositeOp op) { 1105 return processSpecConstantCompositeOp(op); 1106 }) 1107 .Case([&](spirv::SpecConstantOperationOp op) { 1108 return processSpecConstantOperationOp(op); 1109 }) 1110 .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) 1111 .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) 1112 1113 // Then handle all the ops that directly mirror SPIR-V instructions with 1114 // auto-generated methods. 1115 .Default( 1116 [&](Operation *op) { return dispatchToAutogenSerialization(op); }); 1117 } 1118 1119 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, 1120 StringRef extInstSet, 1121 uint32_t opcode) { 1122 SmallVector<uint32_t, 4> operands; 1123 Location loc = op->getLoc(); 1124 1125 uint32_t resultID = 0; 1126 if (op->getNumResults() != 0) { 1127 uint32_t resultTypeID = 0; 1128 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) 1129 return failure(); 1130 operands.push_back(resultTypeID); 1131 1132 resultID = getNextID(); 1133 operands.push_back(resultID); 1134 valueIDMap[op->getResult(0)] = resultID; 1135 }; 1136 1137 for (Value operand : op->getOperands()) 1138 operands.push_back(getValueID(operand)); 1139 1140 (void)emitDebugLine(functionBody, loc); 1141 1142 if (extInstSet.empty()) { 1143 (void)encodeInstructionInto(functionBody, 1144 static_cast<spirv::Opcode>(opcode), operands); 1145 } else { 1146 (void)encodeExtensionInstruction(op, extInstSet, opcode, operands); 1147 } 1148 1149 if (op->getNumResults() != 0) { 1150 for (auto attr : op->getAttrs()) { 1151 if (failed(processDecoration(loc, resultID, attr))) 1152 return failure(); 1153 } 1154 } 1155 1156 return success(); 1157 } 1158 1159 LogicalResult Serializer::emitDecoration(uint32_t target, 1160 spirv::Decoration decoration, 1161 ArrayRef<uint32_t> params) { 1162 uint32_t wordCount = 3 + params.size(); 1163 decorations.push_back( 1164 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); 1165 decorations.push_back(target); 1166 decorations.push_back(static_cast<uint32_t>(decoration)); 1167 decorations.append(params.begin(), params.end()); 1168 return success(); 1169 } 1170 1171 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, 1172 Location loc) { 1173 if (!emitDebugInfo) 1174 return success(); 1175 1176 if (lastProcessedWasMergeInst) { 1177 lastProcessedWasMergeInst = false; 1178 return success(); 1179 } 1180 1181 auto fileLoc = loc.dyn_cast<FileLineColLoc>(); 1182 if (fileLoc) 1183 (void)encodeInstructionInto( 1184 binary, spirv::Opcode::OpLine, 1185 {fileID, fileLoc.getLine(), fileLoc.getColumn()}); 1186 return success(); 1187 } 1188 } // namespace spirv 1189 } // namespace mlir 1190