1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Serializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 18 #include "mlir/Support/LogicalResult.h" 19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/ADT/SmallPtrSet.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/ADT/bit.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "spirv-serialization" 28 29 using namespace mlir; 30 31 /// Returns the merge block if the given `op` is a structured control flow op. 32 /// Otherwise returns nullptr. 33 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { 34 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) 35 return selectionOp.getMergeBlock(); 36 if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) 37 return loopOp.getMergeBlock(); 38 return nullptr; 39 } 40 41 /// Given a predecessor `block` for a block with arguments, returns the block 42 /// that should be used as the parent block for SPIR-V OpPhi instructions 43 /// corresponding to the block arguments. 44 static Block *getPhiIncomingBlock(Block *block) { 45 // If the predecessor block in question is the entry block for a 46 // spv.mlir.loop, we jump to this spv.mlir.loop from its enclosing block. 47 if (block->isEntryBlock()) { 48 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { 49 // Then the incoming parent block for OpPhi should be the merge block of 50 // the structured control flow op before this loop. 51 Operation *op = loopOp.getOperation(); 52 while ((op = op->getPrevNode()) != nullptr) 53 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) 54 return incomingBlock; 55 // Or the enclosing block itself if no structured control flow ops 56 // exists before this loop. 57 return loopOp->getBlock(); 58 } 59 } 60 61 // Otherwise, we jump from the given predecessor block. Try to see if there is 62 // a structured control flow op inside it. 63 for (Operation &op : llvm::reverse(block->getOperations())) { 64 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) 65 return incomingBlock; 66 } 67 return block; 68 } 69 70 namespace mlir { 71 namespace spirv { 72 73 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into 74 /// the given `binary` vector. 75 LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, 76 spirv::Opcode op, 77 ArrayRef<uint32_t> operands) { 78 uint32_t wordCount = 1 + operands.size(); 79 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 80 binary.append(operands.begin(), operands.end()); 81 return success(); 82 } 83 84 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) 85 : module(module), mlirBuilder(module.getContext()), 86 emitDebugInfo(emitDebugInfo) {} 87 88 LogicalResult Serializer::serialize() { 89 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); 90 91 if (failed(module.verify())) 92 return failure(); 93 94 // TODO: handle the other sections 95 processCapability(); 96 processExtension(); 97 processMemoryModel(); 98 processDebugInfo(); 99 100 // Iterate over the module body to serialize it. Assumptions are that there is 101 // only one basic block in the moduleOp 102 for (auto &op : module.getBlock()) { 103 if (failed(processOperation(&op))) { 104 return failure(); 105 } 106 } 107 108 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); 109 return success(); 110 } 111 112 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { 113 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + 114 extensions.size() + extendedSets.size() + 115 memoryModel.size() + entryPoints.size() + 116 executionModes.size() + decorations.size() + 117 typesGlobalValues.size() + functions.size(); 118 119 binary.clear(); 120 binary.reserve(moduleSize); 121 122 spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); 123 binary.append(capabilities.begin(), capabilities.end()); 124 binary.append(extensions.begin(), extensions.end()); 125 binary.append(extendedSets.begin(), extendedSets.end()); 126 binary.append(memoryModel.begin(), memoryModel.end()); 127 binary.append(entryPoints.begin(), entryPoints.end()); 128 binary.append(executionModes.begin(), executionModes.end()); 129 binary.append(debug.begin(), debug.end()); 130 binary.append(names.begin(), names.end()); 131 binary.append(decorations.begin(), decorations.end()); 132 binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); 133 binary.append(functions.begin(), functions.end()); 134 } 135 136 #ifndef NDEBUG 137 void Serializer::printValueIDMap(raw_ostream &os) { 138 os << "\n= Value <id> Map =\n\n"; 139 for (auto valueIDPair : valueIDMap) { 140 Value val = valueIDPair.first; 141 os << " " << val << " " 142 << "id = " << valueIDPair.second << ' '; 143 if (auto *op = val.getDefiningOp()) { 144 os << "from op '" << op->getName() << "'"; 145 } else if (auto arg = val.dyn_cast<BlockArgument>()) { 146 Block *block = arg.getOwner(); 147 os << "from argument of block " << block << ' '; 148 os << " in op '" << block->getParentOp()->getName() << "'"; 149 } 150 os << '\n'; 151 } 152 } 153 #endif 154 155 //===----------------------------------------------------------------------===// 156 // Module structure 157 //===----------------------------------------------------------------------===// 158 159 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { 160 auto funcID = funcIDMap.lookup(fnName); 161 if (!funcID) { 162 funcID = getNextID(); 163 funcIDMap[fnName] = funcID; 164 } 165 return funcID; 166 } 167 168 void Serializer::processCapability() { 169 for (auto cap : module.vce_triple()->getCapabilities()) 170 (void)encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, 171 {static_cast<uint32_t>(cap)}); 172 } 173 174 void Serializer::processDebugInfo() { 175 if (!emitDebugInfo) 176 return; 177 auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>(); 178 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>"; 179 fileID = getNextID(); 180 SmallVector<uint32_t, 16> operands; 181 operands.push_back(fileID); 182 (void)spirv::encodeStringLiteralInto(operands, fileName); 183 (void)encodeInstructionInto(debug, spirv::Opcode::OpString, operands); 184 // TODO: Encode more debug instructions. 185 } 186 187 void Serializer::processExtension() { 188 llvm::SmallVector<uint32_t, 16> extName; 189 for (spirv::Extension ext : module.vce_triple()->getExtensions()) { 190 extName.clear(); 191 (void)spirv::encodeStringLiteralInto(extName, 192 spirv::stringifyExtension(ext)); 193 (void)encodeInstructionInto(extensions, spirv::Opcode::OpExtension, 194 extName); 195 } 196 } 197 198 void Serializer::processMemoryModel() { 199 uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt(); 200 uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt(); 201 202 (void)encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, 203 {am, mm}); 204 } 205 206 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, 207 NamedAttribute attr) { 208 auto attrName = attr.first.strref(); 209 auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); 210 auto decoration = spirv::symbolizeDecoration(decorationName); 211 if (!decoration) { 212 return emitError( 213 loc, "non-argument attributes expected to have snake-case-ified " 214 "decoration name, unhandled attribute with name : ") 215 << attrName; 216 } 217 SmallVector<uint32_t, 1> args; 218 switch (decoration.getValue()) { 219 case spirv::Decoration::Binding: 220 case spirv::Decoration::DescriptorSet: 221 case spirv::Decoration::Location: 222 if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) { 223 args.push_back(intAttr.getValue().getZExtValue()); 224 break; 225 } 226 return emitError(loc, "expected integer attribute for ") << attrName; 227 case spirv::Decoration::BuiltIn: 228 if (auto strAttr = attr.second.dyn_cast<StringAttr>()) { 229 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); 230 if (enumVal) { 231 args.push_back(static_cast<uint32_t>(enumVal.getValue())); 232 break; 233 } 234 return emitError(loc, "invalid ") 235 << attrName << " attribute " << strAttr.getValue(); 236 } 237 return emitError(loc, "expected string attribute for ") << attrName; 238 case spirv::Decoration::Aliased: 239 case spirv::Decoration::Flat: 240 case spirv::Decoration::NonReadable: 241 case spirv::Decoration::NonWritable: 242 case spirv::Decoration::NoPerspective: 243 case spirv::Decoration::Restrict: 244 // For unit attributes, the args list has no values so we do nothing 245 if (auto unitAttr = attr.second.dyn_cast<UnitAttr>()) 246 break; 247 return emitError(loc, "expected unit attribute for ") << attrName; 248 default: 249 return emitError(loc, "unhandled decoration ") << decorationName; 250 } 251 return emitDecoration(resultID, decoration.getValue(), args); 252 } 253 254 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { 255 assert(!name.empty() && "unexpected empty string for OpName"); 256 257 SmallVector<uint32_t, 4> nameOperands; 258 nameOperands.push_back(resultID); 259 if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { 260 return failure(); 261 } 262 return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); 263 } 264 265 template <> 266 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( 267 Location loc, spirv::ArrayType type, uint32_t resultID) { 268 if (unsigned stride = type.getArrayStride()) { 269 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 270 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 271 } 272 return success(); 273 } 274 275 template <> 276 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( 277 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) { 278 if (unsigned stride = type.getArrayStride()) { 279 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 280 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 281 } 282 return success(); 283 } 284 285 LogicalResult Serializer::processMemberDecoration( 286 uint32_t structID, 287 const spirv::StructType::MemberDecorationInfo &memberDecoration) { 288 SmallVector<uint32_t, 4> args( 289 {structID, memberDecoration.memberIndex, 290 static_cast<uint32_t>(memberDecoration.decoration)}); 291 if (memberDecoration.hasValue) { 292 args.push_back(memberDecoration.decorationValue); 293 } 294 return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, 295 args); 296 } 297 298 //===----------------------------------------------------------------------===// 299 // Type 300 //===----------------------------------------------------------------------===// 301 302 // According to the SPIR-V spec "Validation Rules for Shader Capabilities": 303 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and 304 // PushConstant Storage Classes must be explicitly laid out." 305 bool Serializer::isInterfaceStructPtrType(Type type) const { 306 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 307 switch (ptrType.getStorageClass()) { 308 case spirv::StorageClass::PhysicalStorageBuffer: 309 case spirv::StorageClass::PushConstant: 310 case spirv::StorageClass::StorageBuffer: 311 case spirv::StorageClass::Uniform: 312 return ptrType.getPointeeType().isa<spirv::StructType>(); 313 default: 314 break; 315 } 316 } 317 return false; 318 } 319 320 LogicalResult Serializer::processType(Location loc, Type type, 321 uint32_t &typeID) { 322 // Maintains a set of names for nested identified struct types. This is used 323 // to properly serialize recursive references. 324 llvm::SetVector<StringRef> serializationCtx; 325 return processTypeImpl(loc, type, typeID, serializationCtx); 326 } 327 328 LogicalResult 329 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, 330 llvm::SetVector<StringRef> &serializationCtx) { 331 typeID = getTypeID(type); 332 if (typeID) { 333 return success(); 334 } 335 typeID = getNextID(); 336 SmallVector<uint32_t, 4> operands; 337 338 operands.push_back(typeID); 339 auto typeEnum = spirv::Opcode::OpTypeVoid; 340 bool deferSerialization = false; 341 342 if ((type.isa<FunctionType>() && 343 succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum, 344 operands))) || 345 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, 346 deferSerialization, serializationCtx))) { 347 if (deferSerialization) 348 return success(); 349 350 typeIDMap[type] = typeID; 351 352 if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) 353 return failure(); 354 355 if (recursiveStructInfos.count(type) != 0) { 356 // This recursive struct type is emitted already, now the OpTypePointer 357 // instructions referring to recursive references are emitted as well. 358 for (auto &ptrInfo : recursiveStructInfos[type]) { 359 // TODO: This might not work if more than 1 recursive reference is 360 // present in the struct. 361 SmallVector<uint32_t, 4> ptrOperands; 362 ptrOperands.push_back(ptrInfo.pointerTypeID); 363 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass)); 364 ptrOperands.push_back(typeIDMap[type]); 365 366 if (failed(encodeInstructionInto( 367 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) 368 return failure(); 369 } 370 371 recursiveStructInfos[type].clear(); 372 } 373 374 return success(); 375 } 376 377 return failure(); 378 } 379 380 LogicalResult Serializer::prepareBasicType( 381 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, 382 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, 383 llvm::SetVector<StringRef> &serializationCtx) { 384 deferSerialization = false; 385 386 if (isVoidType(type)) { 387 typeEnum = spirv::Opcode::OpTypeVoid; 388 return success(); 389 } 390 391 if (auto intType = type.dyn_cast<IntegerType>()) { 392 if (intType.getWidth() == 1) { 393 typeEnum = spirv::Opcode::OpTypeBool; 394 return success(); 395 } 396 397 typeEnum = spirv::Opcode::OpTypeInt; 398 operands.push_back(intType.getWidth()); 399 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 400 // to preserve or validate. 401 // 0 indicates unsigned, or no signedness semantics 402 // 1 indicates signed semantics." 403 operands.push_back(intType.isSigned() ? 1 : 0); 404 return success(); 405 } 406 407 if (auto floatType = type.dyn_cast<FloatType>()) { 408 typeEnum = spirv::Opcode::OpTypeFloat; 409 operands.push_back(floatType.getWidth()); 410 return success(); 411 } 412 413 if (auto vectorType = type.dyn_cast<VectorType>()) { 414 uint32_t elementTypeID = 0; 415 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, 416 serializationCtx))) { 417 return failure(); 418 } 419 typeEnum = spirv::Opcode::OpTypeVector; 420 operands.push_back(elementTypeID); 421 operands.push_back(vectorType.getNumElements()); 422 return success(); 423 } 424 425 if (auto imageType = type.dyn_cast<spirv::ImageType>()) { 426 typeEnum = spirv::Opcode::OpTypeImage; 427 uint32_t sampledTypeID = 0; 428 if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) 429 return failure(); 430 431 operands.push_back(sampledTypeID); 432 operands.push_back(static_cast<uint32_t>(imageType.getDim())); 433 operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo())); 434 operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo())); 435 operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo())); 436 operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo())); 437 operands.push_back(static_cast<uint32_t>(imageType.getImageFormat())); 438 return success(); 439 } 440 441 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) { 442 typeEnum = spirv::Opcode::OpTypeArray; 443 uint32_t elementTypeID = 0; 444 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, 445 serializationCtx))) { 446 return failure(); 447 } 448 operands.push_back(elementTypeID); 449 if (auto elementCountID = prepareConstantInt( 450 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { 451 operands.push_back(elementCountID); 452 } 453 return processTypeDecoration(loc, arrayType, resultID); 454 } 455 456 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 457 uint32_t pointeeTypeID = 0; 458 spirv::StructType pointeeStruct = 459 ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 460 461 if (pointeeStruct && pointeeStruct.isIdentified() && 462 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { 463 // A recursive reference to an enclosing struct is found. 464 // 465 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage 466 // class as operands. 467 SmallVector<uint32_t, 2> forwardPtrOperands; 468 forwardPtrOperands.push_back(resultID); 469 forwardPtrOperands.push_back( 470 static_cast<uint32_t>(ptrType.getStorageClass())); 471 472 (void)encodeInstructionInto(typesGlobalValues, 473 spirv::Opcode::OpTypeForwardPointer, 474 forwardPtrOperands); 475 476 // 2. Find the pointee (enclosing) struct. 477 auto structType = spirv::StructType::getIdentified( 478 module.getContext(), pointeeStruct.getIdentifier()); 479 480 if (!structType) 481 return failure(); 482 483 // 3. Mark the OpTypePointer that is supposed to be emitted by this call 484 // as deferred. 485 deferSerialization = true; 486 487 // 4. Record the info needed to emit the deferred OpTypePointer 488 // instruction when the enclosing struct is completely serialized. 489 recursiveStructInfos[structType].push_back( 490 {resultID, ptrType.getStorageClass()}); 491 } else { 492 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, 493 serializationCtx))) 494 return failure(); 495 } 496 497 typeEnum = spirv::Opcode::OpTypePointer; 498 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); 499 operands.push_back(pointeeTypeID); 500 return success(); 501 } 502 503 if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) { 504 uint32_t elementTypeID = 0; 505 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), 506 elementTypeID, serializationCtx))) { 507 return failure(); 508 } 509 typeEnum = spirv::Opcode::OpTypeRuntimeArray; 510 operands.push_back(elementTypeID); 511 return processTypeDecoration(loc, runtimeArrayType, resultID); 512 } 513 514 if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) { 515 typeEnum = spirv::Opcode::OpTypeSampledImage; 516 uint32_t imageTypeID = 0; 517 if (failed( 518 processType(loc, sampledImageType.getImageType(), imageTypeID))) { 519 return failure(); 520 } 521 operands.push_back(imageTypeID); 522 return success(); 523 } 524 525 if (auto structType = type.dyn_cast<spirv::StructType>()) { 526 if (structType.isIdentified()) { 527 (void)processName(resultID, structType.getIdentifier()); 528 serializationCtx.insert(structType.getIdentifier()); 529 } 530 531 bool hasOffset = structType.hasOffset(); 532 for (auto elementIndex : 533 llvm::seq<uint32_t>(0, structType.getNumElements())) { 534 uint32_t elementTypeID = 0; 535 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), 536 elementTypeID, serializationCtx))) { 537 return failure(); 538 } 539 operands.push_back(elementTypeID); 540 if (hasOffset) { 541 // Decorate each struct member with an offset 542 spirv::StructType::MemberDecorationInfo offsetDecoration{ 543 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, 544 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; 545 if (failed(processMemberDecoration(resultID, offsetDecoration))) { 546 return emitError(loc, "cannot decorate ") 547 << elementIndex << "-th member of " << structType 548 << " with its offset"; 549 } 550 } 551 } 552 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 553 structType.getMemberDecorations(memberDecorations); 554 555 for (auto &memberDecoration : memberDecorations) { 556 if (failed(processMemberDecoration(resultID, memberDecoration))) { 557 return emitError(loc, "cannot decorate ") 558 << static_cast<uint32_t>(memberDecoration.memberIndex) 559 << "-th member of " << structType << " with " 560 << stringifyDecoration(memberDecoration.decoration); 561 } 562 } 563 564 typeEnum = spirv::Opcode::OpTypeStruct; 565 566 if (structType.isIdentified()) 567 serializationCtx.remove(structType.getIdentifier()); 568 569 return success(); 570 } 571 572 if (auto cooperativeMatrixType = 573 type.dyn_cast<spirv::CooperativeMatrixNVType>()) { 574 uint32_t elementTypeID = 0; 575 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), 576 elementTypeID, serializationCtx))) { 577 return failure(); 578 } 579 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; 580 auto getConstantOp = [&](uint32_t id) { 581 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); 582 return prepareConstantInt(loc, attr); 583 }; 584 operands.push_back(elementTypeID); 585 operands.push_back( 586 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()))); 587 operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); 588 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); 589 return success(); 590 } 591 592 if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) { 593 uint32_t elementTypeID = 0; 594 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, 595 serializationCtx))) { 596 return failure(); 597 } 598 typeEnum = spirv::Opcode::OpTypeMatrix; 599 operands.push_back(elementTypeID); 600 operands.push_back(matrixType.getNumColumns()); 601 return success(); 602 } 603 604 // TODO: Handle other types. 605 return emitError(loc, "unhandled type in serialization: ") << type; 606 } 607 608 LogicalResult 609 Serializer::prepareFunctionType(Location loc, FunctionType type, 610 spirv::Opcode &typeEnum, 611 SmallVectorImpl<uint32_t> &operands) { 612 typeEnum = spirv::Opcode::OpTypeFunction; 613 assert(type.getNumResults() <= 1 && 614 "serialization supports only a single return value"); 615 uint32_t resultID = 0; 616 if (failed(processType( 617 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), 618 resultID))) { 619 return failure(); 620 } 621 operands.push_back(resultID); 622 for (auto &res : type.getInputs()) { 623 uint32_t argTypeID = 0; 624 if (failed(processType(loc, res, argTypeID))) { 625 return failure(); 626 } 627 operands.push_back(argTypeID); 628 } 629 return success(); 630 } 631 632 //===----------------------------------------------------------------------===// 633 // Constant 634 //===----------------------------------------------------------------------===// 635 636 uint32_t Serializer::prepareConstant(Location loc, Type constType, 637 Attribute valueAttr) { 638 if (auto id = prepareConstantScalar(loc, valueAttr)) { 639 return id; 640 } 641 642 // This is a composite literal. We need to handle each component separately 643 // and then emit an OpConstantComposite for the whole. 644 645 if (auto id = getConstantID(valueAttr)) { 646 return id; 647 } 648 649 uint32_t typeID = 0; 650 if (failed(processType(loc, constType, typeID))) { 651 return 0; 652 } 653 654 uint32_t resultID = 0; 655 if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) { 656 int rank = attr.getType().dyn_cast<ShapedType>().getRank(); 657 SmallVector<uint64_t, 4> index(rank); 658 resultID = prepareDenseElementsConstant(loc, constType, attr, 659 /*dim=*/0, index); 660 } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) { 661 resultID = prepareArrayConstant(loc, constType, arrayAttr); 662 } 663 664 if (resultID == 0) { 665 emitError(loc, "cannot serialize attribute: ") << valueAttr; 666 return 0; 667 } 668 669 constIDMap[valueAttr] = resultID; 670 return resultID; 671 } 672 673 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, 674 ArrayAttr attr) { 675 uint32_t typeID = 0; 676 if (failed(processType(loc, constType, typeID))) { 677 return 0; 678 } 679 680 uint32_t resultID = getNextID(); 681 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 682 operands.reserve(attr.size() + 2); 683 auto elementType = constType.cast<spirv::ArrayType>().getElementType(); 684 for (Attribute elementAttr : attr) { 685 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { 686 operands.push_back(elementID); 687 } else { 688 return 0; 689 } 690 } 691 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 692 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 693 694 return resultID; 695 } 696 697 // TODO: Turn the below function into iterative function, instead of 698 // recursive function. 699 uint32_t 700 Serializer::prepareDenseElementsConstant(Location loc, Type constType, 701 DenseElementsAttr valueAttr, int dim, 702 MutableArrayRef<uint64_t> index) { 703 auto shapedType = valueAttr.getType().dyn_cast<ShapedType>(); 704 assert(dim <= shapedType.getRank()); 705 if (shapedType.getRank() == dim) { 706 if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) { 707 return attr.getType().getElementType().isInteger(1) 708 ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index)) 709 : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index)); 710 } 711 if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) { 712 return prepareConstantFp(loc, attr.getValue<FloatAttr>(index)); 713 } 714 return 0; 715 } 716 717 uint32_t typeID = 0; 718 if (failed(processType(loc, constType, typeID))) { 719 return 0; 720 } 721 722 uint32_t resultID = getNextID(); 723 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 724 operands.reserve(shapedType.getDimSize(dim) + 2); 725 auto elementType = constType.cast<spirv::CompositeType>().getElementType(0); 726 for (int i = 0; i < shapedType.getDimSize(dim); ++i) { 727 index[dim] = i; 728 if (auto elementID = prepareDenseElementsConstant( 729 loc, elementType, valueAttr, dim + 1, index)) { 730 operands.push_back(elementID); 731 } else { 732 return 0; 733 } 734 } 735 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 736 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 737 738 return resultID; 739 } 740 741 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, 742 bool isSpec) { 743 if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) { 744 return prepareConstantFp(loc, floatAttr, isSpec); 745 } 746 if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) { 747 return prepareConstantBool(loc, boolAttr, isSpec); 748 } 749 if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) { 750 return prepareConstantInt(loc, intAttr, isSpec); 751 } 752 753 return 0; 754 } 755 756 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, 757 bool isSpec) { 758 if (!isSpec) { 759 // We can de-duplicate normal constants, but not specialization constants. 760 if (auto id = getConstantID(boolAttr)) { 761 return id; 762 } 763 } 764 765 // Process the type for this bool literal 766 uint32_t typeID = 0; 767 if (failed(processType(loc, boolAttr.getType(), typeID))) { 768 return 0; 769 } 770 771 auto resultID = getNextID(); 772 auto opcode = boolAttr.getValue() 773 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue 774 : spirv::Opcode::OpConstantTrue) 775 : (isSpec ? spirv::Opcode::OpSpecConstantFalse 776 : spirv::Opcode::OpConstantFalse); 777 (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); 778 779 if (!isSpec) { 780 constIDMap[boolAttr] = resultID; 781 } 782 return resultID; 783 } 784 785 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, 786 bool isSpec) { 787 if (!isSpec) { 788 // We can de-duplicate normal constants, but not specialization constants. 789 if (auto id = getConstantID(intAttr)) { 790 return id; 791 } 792 } 793 794 // Process the type for this integer literal 795 uint32_t typeID = 0; 796 if (failed(processType(loc, intAttr.getType(), typeID))) { 797 return 0; 798 } 799 800 auto resultID = getNextID(); 801 APInt value = intAttr.getValue(); 802 unsigned bitwidth = value.getBitWidth(); 803 bool isSigned = value.isSignedIntN(bitwidth); 804 805 auto opcode = 806 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 807 808 // According to SPIR-V spec, "When the type's bit width is less than 32-bits, 809 // the literal's value appears in the low-order bits of the word, and the 810 // high-order bits must be 0 for a floating-point type, or 0 for an integer 811 // type with Signedness of 0, or sign extended when Signedness is 1." 812 if (bitwidth == 32 || bitwidth == 16) { 813 uint32_t word = 0; 814 if (isSigned) { 815 word = static_cast<int32_t>(value.getSExtValue()); 816 } else { 817 word = static_cast<uint32_t>(value.getZExtValue()); 818 } 819 (void)encodeInstructionInto(typesGlobalValues, opcode, 820 {typeID, resultID, word}); 821 } 822 // According to SPIR-V spec: "When the type's bit width is larger than one 823 // word, the literal’s low-order words appear first." 824 else if (bitwidth == 64) { 825 struct DoubleWord { 826 uint32_t word1; 827 uint32_t word2; 828 } words; 829 if (isSigned) { 830 words = llvm::bit_cast<DoubleWord>(value.getSExtValue()); 831 } else { 832 words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); 833 } 834 (void)encodeInstructionInto(typesGlobalValues, opcode, 835 {typeID, resultID, words.word1, words.word2}); 836 } else { 837 std::string valueStr; 838 llvm::raw_string_ostream rss(valueStr); 839 value.print(rss, /*isSigned=*/false); 840 841 emitError(loc, "cannot serialize ") 842 << bitwidth << "-bit integer literal: " << rss.str(); 843 return 0; 844 } 845 846 if (!isSpec) { 847 constIDMap[intAttr] = resultID; 848 } 849 return resultID; 850 } 851 852 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, 853 bool isSpec) { 854 if (!isSpec) { 855 // We can de-duplicate normal constants, but not specialization constants. 856 if (auto id = getConstantID(floatAttr)) { 857 return id; 858 } 859 } 860 861 // Process the type for this float literal 862 uint32_t typeID = 0; 863 if (failed(processType(loc, floatAttr.getType(), typeID))) { 864 return 0; 865 } 866 867 auto resultID = getNextID(); 868 APFloat value = floatAttr.getValue(); 869 APInt intValue = value.bitcastToAPInt(); 870 871 auto opcode = 872 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 873 874 if (&value.getSemantics() == &APFloat::IEEEsingle()) { 875 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); 876 (void)encodeInstructionInto(typesGlobalValues, opcode, 877 {typeID, resultID, word}); 878 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { 879 struct DoubleWord { 880 uint32_t word1; 881 uint32_t word2; 882 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); 883 (void)encodeInstructionInto(typesGlobalValues, opcode, 884 {typeID, resultID, words.word1, words.word2}); 885 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { 886 uint32_t word = 887 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); 888 (void)encodeInstructionInto(typesGlobalValues, opcode, 889 {typeID, resultID, word}); 890 } else { 891 std::string valueStr; 892 llvm::raw_string_ostream rss(valueStr); 893 value.print(rss); 894 895 emitError(loc, "cannot serialize ") 896 << floatAttr.getType() << "-typed float literal: " << rss.str(); 897 return 0; 898 } 899 900 if (!isSpec) { 901 constIDMap[floatAttr] = resultID; 902 } 903 return resultID; 904 } 905 906 //===----------------------------------------------------------------------===// 907 // Control flow 908 //===----------------------------------------------------------------------===// 909 910 uint32_t Serializer::getOrCreateBlockID(Block *block) { 911 if (uint32_t id = getBlockID(block)) 912 return id; 913 return blockIDMap[block] = getNextID(); 914 } 915 916 LogicalResult 917 Serializer::processBlock(Block *block, bool omitLabel, 918 function_ref<void()> actionBeforeTerminator) { 919 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); 920 LLVM_DEBUG(block->print(llvm::dbgs())); 921 LLVM_DEBUG(llvm::dbgs() << '\n'); 922 if (!omitLabel) { 923 uint32_t blockID = getOrCreateBlockID(block); 924 LLVM_DEBUG(llvm::dbgs() 925 << "[block] " << block << " (id = " << blockID << ")\n"); 926 927 // Emit OpLabel for this block. 928 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, 929 {blockID}); 930 } 931 932 // Emit OpPhi instructions for block arguments, if any. 933 if (failed(emitPhiForBlockArguments(block))) 934 return failure(); 935 936 // Process each op in this block except the terminator. 937 for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { 938 if (failed(processOperation(&op))) 939 return failure(); 940 } 941 942 // Process the terminator. 943 if (actionBeforeTerminator) 944 actionBeforeTerminator(); 945 if (failed(processOperation(&block->back()))) 946 return failure(); 947 948 return success(); 949 } 950 951 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { 952 // Nothing to do if this block has no arguments or it's the entry block, which 953 // always has the same arguments as the function signature. 954 if (block->args_empty() || block->isEntryBlock()) 955 return success(); 956 957 // If the block has arguments, we need to create SPIR-V OpPhi instructions. 958 // A SPIR-V OpPhi instruction is of the syntax: 959 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair 960 // So we need to collect all predecessor blocks and the arguments they send 961 // to this block. 962 SmallVector<std::pair<Block *, Operation::operand_iterator>, 4> predecessors; 963 for (Block *predecessor : block->getPredecessors()) { 964 auto *terminator = predecessor->getTerminator(); 965 // The predecessor here is the immediate one according to MLIR's IR 966 // structure. It does not directly map to the incoming parent block for the 967 // OpPhi instructions at SPIR-V binary level. This is because structured 968 // control flow ops are serialized to multiple SPIR-V blocks. If there is a 969 // spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the 970 // branch op jumping to the OpPhi's block then resides in the previous 971 // structured control flow op's merge block. 972 predecessor = getPhiIncomingBlock(predecessor); 973 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { 974 predecessors.emplace_back(predecessor, branchOp.operand_begin()); 975 } else { 976 return terminator->emitError("unimplemented terminator for Phi creation"); 977 } 978 } 979 980 // Then create OpPhi instruction for each of the block argument. 981 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) { 982 BlockArgument arg = block->getArgument(argIndex); 983 984 // Get the type <id> and result <id> for this OpPhi instruction. 985 uint32_t phiTypeID = 0; 986 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) 987 return failure(); 988 uint32_t phiID = getNextID(); 989 990 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' 991 << arg << " (id = " << phiID << ")\n"); 992 993 // Prepare the (value <id>, parent block <id>) pairs. 994 SmallVector<uint32_t, 8> phiArgs; 995 phiArgs.push_back(phiTypeID); 996 phiArgs.push_back(phiID); 997 998 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) { 999 Value value = *(predecessors[predIndex].second + argIndex); 1000 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); 1001 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId 1002 << ") value " << value << ' '); 1003 // Each pair is a value <id> ... 1004 uint32_t valueId = getValueID(value); 1005 if (valueId == 0) { 1006 // The op generating this value hasn't been visited yet so we don't have 1007 // an <id> assigned yet. Record this to fix up later. 1008 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); 1009 deferredPhiValues[value].push_back(functionBody.size() + 1 + 1010 phiArgs.size()); 1011 } else { 1012 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); 1013 } 1014 phiArgs.push_back(valueId); 1015 // ... and a parent block <id>. 1016 phiArgs.push_back(predBlockId); 1017 } 1018 1019 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); 1020 valueIDMap[arg] = phiID; 1021 } 1022 1023 return success(); 1024 } 1025 1026 //===----------------------------------------------------------------------===// 1027 // Operation 1028 //===----------------------------------------------------------------------===// 1029 1030 LogicalResult Serializer::encodeExtensionInstruction( 1031 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, 1032 ArrayRef<uint32_t> operands) { 1033 // Check if the extension has been imported. 1034 auto &setID = extendedInstSetIDMap[extensionSetName]; 1035 if (!setID) { 1036 setID = getNextID(); 1037 SmallVector<uint32_t, 16> importOperands; 1038 importOperands.push_back(setID); 1039 if (failed( 1040 spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || 1041 failed(encodeInstructionInto( 1042 extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { 1043 return failure(); 1044 } 1045 } 1046 1047 // The first two operands are the result type <id> and result <id>. The set 1048 // <id> and the opcode need to be insert after this. 1049 if (operands.size() < 2) { 1050 return op->emitError("extended instructions must have a result encoding"); 1051 } 1052 SmallVector<uint32_t, 8> extInstOperands; 1053 extInstOperands.reserve(operands.size() + 2); 1054 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); 1055 extInstOperands.push_back(setID); 1056 extInstOperands.push_back(extensionOpcode); 1057 extInstOperands.append(std::next(operands.begin(), 2), operands.end()); 1058 return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, 1059 extInstOperands); 1060 } 1061 1062 LogicalResult Serializer::processOperation(Operation *opInst) { 1063 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); 1064 1065 // First dispatch the ops that do not directly mirror an instruction from 1066 // the SPIR-V spec. 1067 return TypeSwitch<Operation *, LogicalResult>(opInst) 1068 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) 1069 .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) 1070 .Case([&](spirv::BranchConditionalOp op) { 1071 return processBranchConditionalOp(op); 1072 }) 1073 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) 1074 .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) 1075 .Case([&](spirv::GlobalVariableOp op) { 1076 return processGlobalVariableOp(op); 1077 }) 1078 .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) 1079 .Case([&](spirv::ModuleEndOp) { return success(); }) 1080 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) 1081 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) 1082 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) 1083 .Case([&](spirv::SpecConstantCompositeOp op) { 1084 return processSpecConstantCompositeOp(op); 1085 }) 1086 .Case([&](spirv::SpecConstantOperationOp op) { 1087 return processSpecConstantOperationOp(op); 1088 }) 1089 .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) 1090 .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) 1091 1092 // Then handle all the ops that directly mirror SPIR-V instructions with 1093 // auto-generated methods. 1094 .Default( 1095 [&](Operation *op) { return dispatchToAutogenSerialization(op); }); 1096 } 1097 1098 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, 1099 StringRef extInstSet, 1100 uint32_t opcode) { 1101 SmallVector<uint32_t, 4> operands; 1102 Location loc = op->getLoc(); 1103 1104 uint32_t resultID = 0; 1105 if (op->getNumResults() != 0) { 1106 uint32_t resultTypeID = 0; 1107 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) 1108 return failure(); 1109 operands.push_back(resultTypeID); 1110 1111 resultID = getNextID(); 1112 operands.push_back(resultID); 1113 valueIDMap[op->getResult(0)] = resultID; 1114 }; 1115 1116 for (Value operand : op->getOperands()) 1117 operands.push_back(getValueID(operand)); 1118 1119 (void)emitDebugLine(functionBody, loc); 1120 1121 if (extInstSet.empty()) { 1122 (void)encodeInstructionInto(functionBody, 1123 static_cast<spirv::Opcode>(opcode), operands); 1124 } else { 1125 (void)encodeExtensionInstruction(op, extInstSet, opcode, operands); 1126 } 1127 1128 if (op->getNumResults() != 0) { 1129 for (auto attr : op->getAttrs()) { 1130 if (failed(processDecoration(loc, resultID, attr))) 1131 return failure(); 1132 } 1133 } 1134 1135 return success(); 1136 } 1137 1138 LogicalResult Serializer::emitDecoration(uint32_t target, 1139 spirv::Decoration decoration, 1140 ArrayRef<uint32_t> params) { 1141 uint32_t wordCount = 3 + params.size(); 1142 decorations.push_back( 1143 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); 1144 decorations.push_back(target); 1145 decorations.push_back(static_cast<uint32_t>(decoration)); 1146 decorations.append(params.begin(), params.end()); 1147 return success(); 1148 } 1149 1150 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, 1151 Location loc) { 1152 if (!emitDebugInfo) 1153 return success(); 1154 1155 if (lastProcessedWasMergeInst) { 1156 lastProcessedWasMergeInst = false; 1157 return success(); 1158 } 1159 1160 auto fileLoc = loc.dyn_cast<FileLineColLoc>(); 1161 if (fileLoc) 1162 (void)encodeInstructionInto( 1163 binary, spirv::Opcode::OpLine, 1164 {fileID, fileLoc.getLine(), fileLoc.getColumn()}); 1165 return success(); 1166 } 1167 } // namespace spirv 1168 } // namespace mlir 1169