1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Serializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 18 #include "mlir/Support/LogicalResult.h" 19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/ADT/SmallPtrSet.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/ADT/bit.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "spirv-serialization" 28 29 using namespace mlir; 30 31 /// Returns the merge block if the given `op` is a structured control flow op. 32 /// Otherwise returns nullptr. 33 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { 34 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) 35 return selectionOp.getMergeBlock(); 36 if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) 37 return loopOp.getMergeBlock(); 38 return nullptr; 39 } 40 41 /// Given a predecessor `block` for a block with arguments, returns the block 42 /// that should be used as the parent block for SPIR-V OpPhi instructions 43 /// corresponding to the block arguments. 44 static Block *getPhiIncomingBlock(Block *block) { 45 // If the predecessor block in question is the entry block for a 46 // spv.mlir.loop, we jump to this spv.mlir.loop from its enclosing block. 47 if (block->isEntryBlock()) { 48 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { 49 // Then the incoming parent block for OpPhi should be the merge block of 50 // the structured control flow op before this loop. 51 Operation *op = loopOp.getOperation(); 52 while ((op = op->getPrevNode()) != nullptr) 53 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) 54 return incomingBlock; 55 // Or the enclosing block itself if no structured control flow ops 56 // exists before this loop. 57 return loopOp->getBlock(); 58 } 59 } 60 61 // Otherwise, we jump from the given predecessor block. Try to see if there is 62 // a structured control flow op inside it. 63 for (Operation &op : llvm::reverse(block->getOperations())) { 64 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) 65 return incomingBlock; 66 } 67 return block; 68 } 69 70 namespace mlir { 71 namespace spirv { 72 73 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into 74 /// the given `binary` vector. 75 LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, 76 spirv::Opcode op, 77 ArrayRef<uint32_t> operands) { 78 uint32_t wordCount = 1 + operands.size(); 79 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 80 binary.append(operands.begin(), operands.end()); 81 return success(); 82 } 83 84 Serializer::Serializer(spirv::ModuleOp module, 85 const SerializationOptions &options) 86 : module(module), mlirBuilder(module.getContext()), options(options) {} 87 88 LogicalResult Serializer::serialize() { 89 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); 90 91 if (failed(module.verify())) 92 return failure(); 93 94 // TODO: handle the other sections 95 processCapability(); 96 processExtension(); 97 processMemoryModel(); 98 processDebugInfo(); 99 100 // Iterate over the module body to serialize it. Assumptions are that there is 101 // only one basic block in the moduleOp 102 for (auto &op : *module.getBody()) { 103 if (failed(processOperation(&op))) { 104 return failure(); 105 } 106 } 107 108 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); 109 return success(); 110 } 111 112 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { 113 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + 114 extensions.size() + extendedSets.size() + 115 memoryModel.size() + entryPoints.size() + 116 executionModes.size() + decorations.size() + 117 typesGlobalValues.size() + functions.size(); 118 119 binary.clear(); 120 binary.reserve(moduleSize); 121 122 spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); 123 binary.append(capabilities.begin(), capabilities.end()); 124 binary.append(extensions.begin(), extensions.end()); 125 binary.append(extendedSets.begin(), extendedSets.end()); 126 binary.append(memoryModel.begin(), memoryModel.end()); 127 binary.append(entryPoints.begin(), entryPoints.end()); 128 binary.append(executionModes.begin(), executionModes.end()); 129 binary.append(debug.begin(), debug.end()); 130 binary.append(names.begin(), names.end()); 131 binary.append(decorations.begin(), decorations.end()); 132 binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); 133 binary.append(functions.begin(), functions.end()); 134 } 135 136 #ifndef NDEBUG 137 void Serializer::printValueIDMap(raw_ostream &os) { 138 os << "\n= Value <id> Map =\n\n"; 139 for (auto valueIDPair : valueIDMap) { 140 Value val = valueIDPair.first; 141 os << " " << val << " " 142 << "id = " << valueIDPair.second << ' '; 143 if (auto *op = val.getDefiningOp()) { 144 os << "from op '" << op->getName() << "'"; 145 } else if (auto arg = val.dyn_cast<BlockArgument>()) { 146 Block *block = arg.getOwner(); 147 os << "from argument of block " << block << ' '; 148 os << " in op '" << block->getParentOp()->getName() << "'"; 149 } 150 os << '\n'; 151 } 152 } 153 #endif 154 155 //===----------------------------------------------------------------------===// 156 // Module structure 157 //===----------------------------------------------------------------------===// 158 159 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { 160 auto funcID = funcIDMap.lookup(fnName); 161 if (!funcID) { 162 funcID = getNextID(); 163 funcIDMap[fnName] = funcID; 164 } 165 return funcID; 166 } 167 168 void Serializer::processCapability() { 169 for (auto cap : module.vce_triple()->getCapabilities()) 170 (void)encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, 171 {static_cast<uint32_t>(cap)}); 172 } 173 174 void Serializer::processDebugInfo() { 175 if (!options.emitDebugInfo) 176 return; 177 auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>(); 178 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>"; 179 fileID = getNextID(); 180 SmallVector<uint32_t, 16> operands; 181 operands.push_back(fileID); 182 (void)spirv::encodeStringLiteralInto(operands, fileName); 183 (void)encodeInstructionInto(debug, spirv::Opcode::OpString, operands); 184 // TODO: Encode more debug instructions. 185 } 186 187 void Serializer::processExtension() { 188 llvm::SmallVector<uint32_t, 16> extName; 189 for (spirv::Extension ext : module.vce_triple()->getExtensions()) { 190 extName.clear(); 191 (void)spirv::encodeStringLiteralInto(extName, 192 spirv::stringifyExtension(ext)); 193 (void)encodeInstructionInto(extensions, spirv::Opcode::OpExtension, 194 extName); 195 } 196 } 197 198 void Serializer::processMemoryModel() { 199 uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt(); 200 uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt(); 201 202 (void)encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, 203 {am, mm}); 204 } 205 206 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, 207 NamedAttribute attr) { 208 auto attrName = attr.getName().strref(); 209 auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); 210 auto decoration = spirv::symbolizeDecoration(decorationName); 211 if (!decoration) { 212 return emitError( 213 loc, "non-argument attributes expected to have snake-case-ified " 214 "decoration name, unhandled attribute with name : ") 215 << attrName; 216 } 217 SmallVector<uint32_t, 1> args; 218 switch (decoration.getValue()) { 219 case spirv::Decoration::Binding: 220 case spirv::Decoration::DescriptorSet: 221 case spirv::Decoration::Location: 222 if (auto intAttr = attr.getValue().dyn_cast<IntegerAttr>()) { 223 args.push_back(intAttr.getValue().getZExtValue()); 224 break; 225 } 226 return emitError(loc, "expected integer attribute for ") << attrName; 227 case spirv::Decoration::BuiltIn: 228 if (auto strAttr = attr.getValue().dyn_cast<StringAttr>()) { 229 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); 230 if (enumVal) { 231 args.push_back(static_cast<uint32_t>(enumVal.getValue())); 232 break; 233 } 234 return emitError(loc, "invalid ") 235 << attrName << " attribute " << strAttr.getValue(); 236 } 237 return emitError(loc, "expected string attribute for ") << attrName; 238 case spirv::Decoration::Aliased: 239 case spirv::Decoration::Flat: 240 case spirv::Decoration::NonReadable: 241 case spirv::Decoration::NonWritable: 242 case spirv::Decoration::NoPerspective: 243 case spirv::Decoration::Restrict: 244 case spirv::Decoration::RelaxedPrecision: 245 // For unit attributes, the args list has no values so we do nothing 246 if (auto unitAttr = attr.getValue().dyn_cast<UnitAttr>()) 247 break; 248 return emitError(loc, "expected unit attribute for ") << attrName; 249 default: 250 return emitError(loc, "unhandled decoration ") << decorationName; 251 } 252 return emitDecoration(resultID, decoration.getValue(), args); 253 } 254 255 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { 256 assert(!name.empty() && "unexpected empty string for OpName"); 257 if (!options.emitSymbolName) 258 return success(); 259 260 SmallVector<uint32_t, 4> nameOperands; 261 nameOperands.push_back(resultID); 262 if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) 263 return failure(); 264 return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); 265 } 266 267 template <> 268 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( 269 Location loc, spirv::ArrayType type, uint32_t resultID) { 270 if (unsigned stride = type.getArrayStride()) { 271 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 272 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 273 } 274 return success(); 275 } 276 277 template <> 278 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( 279 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) { 280 if (unsigned stride = type.getArrayStride()) { 281 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 282 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 283 } 284 return success(); 285 } 286 287 LogicalResult Serializer::processMemberDecoration( 288 uint32_t structID, 289 const spirv::StructType::MemberDecorationInfo &memberDecoration) { 290 SmallVector<uint32_t, 4> args( 291 {structID, memberDecoration.memberIndex, 292 static_cast<uint32_t>(memberDecoration.decoration)}); 293 if (memberDecoration.hasValue) { 294 args.push_back(memberDecoration.decorationValue); 295 } 296 return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, 297 args); 298 } 299 300 //===----------------------------------------------------------------------===// 301 // Type 302 //===----------------------------------------------------------------------===// 303 304 // According to the SPIR-V spec "Validation Rules for Shader Capabilities": 305 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and 306 // PushConstant Storage Classes must be explicitly laid out." 307 bool Serializer::isInterfaceStructPtrType(Type type) const { 308 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 309 switch (ptrType.getStorageClass()) { 310 case spirv::StorageClass::PhysicalStorageBuffer: 311 case spirv::StorageClass::PushConstant: 312 case spirv::StorageClass::StorageBuffer: 313 case spirv::StorageClass::Uniform: 314 return ptrType.getPointeeType().isa<spirv::StructType>(); 315 default: 316 break; 317 } 318 } 319 return false; 320 } 321 322 LogicalResult Serializer::processType(Location loc, Type type, 323 uint32_t &typeID) { 324 // Maintains a set of names for nested identified struct types. This is used 325 // to properly serialize recursive references. 326 SetVector<StringRef> serializationCtx; 327 return processTypeImpl(loc, type, typeID, serializationCtx); 328 } 329 330 LogicalResult 331 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, 332 SetVector<StringRef> &serializationCtx) { 333 typeID = getTypeID(type); 334 if (typeID) { 335 return success(); 336 } 337 typeID = getNextID(); 338 SmallVector<uint32_t, 4> operands; 339 340 operands.push_back(typeID); 341 auto typeEnum = spirv::Opcode::OpTypeVoid; 342 bool deferSerialization = false; 343 344 if ((type.isa<FunctionType>() && 345 succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum, 346 operands))) || 347 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, 348 deferSerialization, serializationCtx))) { 349 if (deferSerialization) 350 return success(); 351 352 typeIDMap[type] = typeID; 353 354 if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) 355 return failure(); 356 357 if (recursiveStructInfos.count(type) != 0) { 358 // This recursive struct type is emitted already, now the OpTypePointer 359 // instructions referring to recursive references are emitted as well. 360 for (auto &ptrInfo : recursiveStructInfos[type]) { 361 // TODO: This might not work if more than 1 recursive reference is 362 // present in the struct. 363 SmallVector<uint32_t, 4> ptrOperands; 364 ptrOperands.push_back(ptrInfo.pointerTypeID); 365 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass)); 366 ptrOperands.push_back(typeIDMap[type]); 367 368 if (failed(encodeInstructionInto( 369 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) 370 return failure(); 371 } 372 373 recursiveStructInfos[type].clear(); 374 } 375 376 return success(); 377 } 378 379 return failure(); 380 } 381 382 LogicalResult Serializer::prepareBasicType( 383 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, 384 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, 385 SetVector<StringRef> &serializationCtx) { 386 deferSerialization = false; 387 388 if (isVoidType(type)) { 389 typeEnum = spirv::Opcode::OpTypeVoid; 390 return success(); 391 } 392 393 if (auto intType = type.dyn_cast<IntegerType>()) { 394 if (intType.getWidth() == 1) { 395 typeEnum = spirv::Opcode::OpTypeBool; 396 return success(); 397 } 398 399 typeEnum = spirv::Opcode::OpTypeInt; 400 operands.push_back(intType.getWidth()); 401 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 402 // to preserve or validate. 403 // 0 indicates unsigned, or no signedness semantics 404 // 1 indicates signed semantics." 405 operands.push_back(intType.isSigned() ? 1 : 0); 406 return success(); 407 } 408 409 if (auto floatType = type.dyn_cast<FloatType>()) { 410 typeEnum = spirv::Opcode::OpTypeFloat; 411 operands.push_back(floatType.getWidth()); 412 return success(); 413 } 414 415 if (auto vectorType = type.dyn_cast<VectorType>()) { 416 uint32_t elementTypeID = 0; 417 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, 418 serializationCtx))) { 419 return failure(); 420 } 421 typeEnum = spirv::Opcode::OpTypeVector; 422 operands.push_back(elementTypeID); 423 operands.push_back(vectorType.getNumElements()); 424 return success(); 425 } 426 427 if (auto imageType = type.dyn_cast<spirv::ImageType>()) { 428 typeEnum = spirv::Opcode::OpTypeImage; 429 uint32_t sampledTypeID = 0; 430 if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) 431 return failure(); 432 433 operands.push_back(sampledTypeID); 434 operands.push_back(static_cast<uint32_t>(imageType.getDim())); 435 operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo())); 436 operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo())); 437 operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo())); 438 operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo())); 439 operands.push_back(static_cast<uint32_t>(imageType.getImageFormat())); 440 return success(); 441 } 442 443 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) { 444 typeEnum = spirv::Opcode::OpTypeArray; 445 uint32_t elementTypeID = 0; 446 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, 447 serializationCtx))) { 448 return failure(); 449 } 450 operands.push_back(elementTypeID); 451 if (auto elementCountID = prepareConstantInt( 452 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { 453 operands.push_back(elementCountID); 454 } 455 return processTypeDecoration(loc, arrayType, resultID); 456 } 457 458 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 459 uint32_t pointeeTypeID = 0; 460 spirv::StructType pointeeStruct = 461 ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 462 463 if (pointeeStruct && pointeeStruct.isIdentified() && 464 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { 465 // A recursive reference to an enclosing struct is found. 466 // 467 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage 468 // class as operands. 469 SmallVector<uint32_t, 2> forwardPtrOperands; 470 forwardPtrOperands.push_back(resultID); 471 forwardPtrOperands.push_back( 472 static_cast<uint32_t>(ptrType.getStorageClass())); 473 474 (void)encodeInstructionInto(typesGlobalValues, 475 spirv::Opcode::OpTypeForwardPointer, 476 forwardPtrOperands); 477 478 // 2. Find the pointee (enclosing) struct. 479 auto structType = spirv::StructType::getIdentified( 480 module.getContext(), pointeeStruct.getIdentifier()); 481 482 if (!structType) 483 return failure(); 484 485 // 3. Mark the OpTypePointer that is supposed to be emitted by this call 486 // as deferred. 487 deferSerialization = true; 488 489 // 4. Record the info needed to emit the deferred OpTypePointer 490 // instruction when the enclosing struct is completely serialized. 491 recursiveStructInfos[structType].push_back( 492 {resultID, ptrType.getStorageClass()}); 493 } else { 494 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, 495 serializationCtx))) 496 return failure(); 497 } 498 499 typeEnum = spirv::Opcode::OpTypePointer; 500 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); 501 operands.push_back(pointeeTypeID); 502 return success(); 503 } 504 505 if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) { 506 uint32_t elementTypeID = 0; 507 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), 508 elementTypeID, serializationCtx))) { 509 return failure(); 510 } 511 typeEnum = spirv::Opcode::OpTypeRuntimeArray; 512 operands.push_back(elementTypeID); 513 return processTypeDecoration(loc, runtimeArrayType, resultID); 514 } 515 516 if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) { 517 typeEnum = spirv::Opcode::OpTypeSampledImage; 518 uint32_t imageTypeID = 0; 519 if (failed( 520 processType(loc, sampledImageType.getImageType(), imageTypeID))) { 521 return failure(); 522 } 523 operands.push_back(imageTypeID); 524 return success(); 525 } 526 527 if (auto structType = type.dyn_cast<spirv::StructType>()) { 528 if (structType.isIdentified()) { 529 (void)processName(resultID, structType.getIdentifier()); 530 serializationCtx.insert(structType.getIdentifier()); 531 } 532 533 bool hasOffset = structType.hasOffset(); 534 for (auto elementIndex : 535 llvm::seq<uint32_t>(0, structType.getNumElements())) { 536 uint32_t elementTypeID = 0; 537 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), 538 elementTypeID, serializationCtx))) { 539 return failure(); 540 } 541 operands.push_back(elementTypeID); 542 if (hasOffset) { 543 // Decorate each struct member with an offset 544 spirv::StructType::MemberDecorationInfo offsetDecoration{ 545 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, 546 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; 547 if (failed(processMemberDecoration(resultID, offsetDecoration))) { 548 return emitError(loc, "cannot decorate ") 549 << elementIndex << "-th member of " << structType 550 << " with its offset"; 551 } 552 } 553 } 554 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 555 structType.getMemberDecorations(memberDecorations); 556 557 for (auto &memberDecoration : memberDecorations) { 558 if (failed(processMemberDecoration(resultID, memberDecoration))) { 559 return emitError(loc, "cannot decorate ") 560 << static_cast<uint32_t>(memberDecoration.memberIndex) 561 << "-th member of " << structType << " with " 562 << stringifyDecoration(memberDecoration.decoration); 563 } 564 } 565 566 typeEnum = spirv::Opcode::OpTypeStruct; 567 568 if (structType.isIdentified()) 569 serializationCtx.remove(structType.getIdentifier()); 570 571 return success(); 572 } 573 574 if (auto cooperativeMatrixType = 575 type.dyn_cast<spirv::CooperativeMatrixNVType>()) { 576 uint32_t elementTypeID = 0; 577 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), 578 elementTypeID, serializationCtx))) { 579 return failure(); 580 } 581 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; 582 auto getConstantOp = [&](uint32_t id) { 583 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); 584 return prepareConstantInt(loc, attr); 585 }; 586 operands.push_back(elementTypeID); 587 operands.push_back( 588 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()))); 589 operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); 590 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); 591 return success(); 592 } 593 594 if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) { 595 uint32_t elementTypeID = 0; 596 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, 597 serializationCtx))) { 598 return failure(); 599 } 600 typeEnum = spirv::Opcode::OpTypeMatrix; 601 operands.push_back(elementTypeID); 602 operands.push_back(matrixType.getNumColumns()); 603 return success(); 604 } 605 606 // TODO: Handle other types. 607 return emitError(loc, "unhandled type in serialization: ") << type; 608 } 609 610 LogicalResult 611 Serializer::prepareFunctionType(Location loc, FunctionType type, 612 spirv::Opcode &typeEnum, 613 SmallVectorImpl<uint32_t> &operands) { 614 typeEnum = spirv::Opcode::OpTypeFunction; 615 assert(type.getNumResults() <= 1 && 616 "serialization supports only a single return value"); 617 uint32_t resultID = 0; 618 if (failed(processType( 619 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), 620 resultID))) { 621 return failure(); 622 } 623 operands.push_back(resultID); 624 for (auto &res : type.getInputs()) { 625 uint32_t argTypeID = 0; 626 if (failed(processType(loc, res, argTypeID))) { 627 return failure(); 628 } 629 operands.push_back(argTypeID); 630 } 631 return success(); 632 } 633 634 //===----------------------------------------------------------------------===// 635 // Constant 636 //===----------------------------------------------------------------------===// 637 638 uint32_t Serializer::prepareConstant(Location loc, Type constType, 639 Attribute valueAttr) { 640 if (auto id = prepareConstantScalar(loc, valueAttr)) { 641 return id; 642 } 643 644 // This is a composite literal. We need to handle each component separately 645 // and then emit an OpConstantComposite for the whole. 646 647 if (auto id = getConstantID(valueAttr)) { 648 return id; 649 } 650 651 uint32_t typeID = 0; 652 if (failed(processType(loc, constType, typeID))) { 653 return 0; 654 } 655 656 uint32_t resultID = 0; 657 if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) { 658 int rank = attr.getType().dyn_cast<ShapedType>().getRank(); 659 SmallVector<uint64_t, 4> index(rank); 660 resultID = prepareDenseElementsConstant(loc, constType, attr, 661 /*dim=*/0, index); 662 } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) { 663 resultID = prepareArrayConstant(loc, constType, arrayAttr); 664 } 665 666 if (resultID == 0) { 667 emitError(loc, "cannot serialize attribute: ") << valueAttr; 668 return 0; 669 } 670 671 constIDMap[valueAttr] = resultID; 672 return resultID; 673 } 674 675 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, 676 ArrayAttr attr) { 677 uint32_t typeID = 0; 678 if (failed(processType(loc, constType, typeID))) { 679 return 0; 680 } 681 682 uint32_t resultID = getNextID(); 683 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 684 operands.reserve(attr.size() + 2); 685 auto elementType = constType.cast<spirv::ArrayType>().getElementType(); 686 for (Attribute elementAttr : attr) { 687 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { 688 operands.push_back(elementID); 689 } else { 690 return 0; 691 } 692 } 693 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 694 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 695 696 return resultID; 697 } 698 699 // TODO: Turn the below function into iterative function, instead of 700 // recursive function. 701 uint32_t 702 Serializer::prepareDenseElementsConstant(Location loc, Type constType, 703 DenseElementsAttr valueAttr, int dim, 704 MutableArrayRef<uint64_t> index) { 705 auto shapedType = valueAttr.getType().dyn_cast<ShapedType>(); 706 assert(dim <= shapedType.getRank()); 707 if (shapedType.getRank() == dim) { 708 if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) { 709 return attr.getType().getElementType().isInteger(1) 710 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index]) 711 : prepareConstantInt(loc, 712 attr.getValues<IntegerAttr>()[index]); 713 } 714 if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) { 715 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]); 716 } 717 return 0; 718 } 719 720 uint32_t typeID = 0; 721 if (failed(processType(loc, constType, typeID))) { 722 return 0; 723 } 724 725 uint32_t resultID = getNextID(); 726 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 727 operands.reserve(shapedType.getDimSize(dim) + 2); 728 auto elementType = constType.cast<spirv::CompositeType>().getElementType(0); 729 for (int i = 0; i < shapedType.getDimSize(dim); ++i) { 730 index[dim] = i; 731 if (auto elementID = prepareDenseElementsConstant( 732 loc, elementType, valueAttr, dim + 1, index)) { 733 operands.push_back(elementID); 734 } else { 735 return 0; 736 } 737 } 738 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 739 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 740 741 return resultID; 742 } 743 744 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, 745 bool isSpec) { 746 if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) { 747 return prepareConstantFp(loc, floatAttr, isSpec); 748 } 749 if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) { 750 return prepareConstantBool(loc, boolAttr, isSpec); 751 } 752 if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) { 753 return prepareConstantInt(loc, intAttr, isSpec); 754 } 755 756 return 0; 757 } 758 759 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, 760 bool isSpec) { 761 if (!isSpec) { 762 // We can de-duplicate normal constants, but not specialization constants. 763 if (auto id = getConstantID(boolAttr)) { 764 return id; 765 } 766 } 767 768 // Process the type for this bool literal 769 uint32_t typeID = 0; 770 if (failed(processType(loc, boolAttr.getType(), typeID))) { 771 return 0; 772 } 773 774 auto resultID = getNextID(); 775 auto opcode = boolAttr.getValue() 776 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue 777 : spirv::Opcode::OpConstantTrue) 778 : (isSpec ? spirv::Opcode::OpSpecConstantFalse 779 : spirv::Opcode::OpConstantFalse); 780 (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); 781 782 if (!isSpec) { 783 constIDMap[boolAttr] = resultID; 784 } 785 return resultID; 786 } 787 788 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, 789 bool isSpec) { 790 if (!isSpec) { 791 // We can de-duplicate normal constants, but not specialization constants. 792 if (auto id = getConstantID(intAttr)) { 793 return id; 794 } 795 } 796 797 // Process the type for this integer literal 798 uint32_t typeID = 0; 799 if (failed(processType(loc, intAttr.getType(), typeID))) { 800 return 0; 801 } 802 803 auto resultID = getNextID(); 804 APInt value = intAttr.getValue(); 805 unsigned bitwidth = value.getBitWidth(); 806 bool isSigned = value.isSignedIntN(bitwidth); 807 808 auto opcode = 809 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 810 811 switch (bitwidth) { 812 // According to SPIR-V spec, "When the type's bit width is less than 813 // 32-bits, the literal's value appears in the low-order bits of the word, 814 // and the high-order bits must be 0 for a floating-point type, or 0 for an 815 // integer type with Signedness of 0, or sign extended when Signedness 816 // is 1." 817 case 32: 818 case 16: 819 case 8: { 820 uint32_t word = 0; 821 if (isSigned) { 822 word = static_cast<int32_t>(value.getSExtValue()); 823 } else { 824 word = static_cast<uint32_t>(value.getZExtValue()); 825 } 826 (void)encodeInstructionInto(typesGlobalValues, opcode, 827 {typeID, resultID, word}); 828 } break; 829 // According to SPIR-V spec: "When the type's bit width is larger than one 830 // word, the literal’s low-order words appear first." 831 case 64: { 832 struct DoubleWord { 833 uint32_t word1; 834 uint32_t word2; 835 } words; 836 if (isSigned) { 837 words = llvm::bit_cast<DoubleWord>(value.getSExtValue()); 838 } else { 839 words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); 840 } 841 (void)encodeInstructionInto(typesGlobalValues, opcode, 842 {typeID, resultID, words.word1, words.word2}); 843 } break; 844 default: { 845 std::string valueStr; 846 llvm::raw_string_ostream rss(valueStr); 847 value.print(rss, /*isSigned=*/false); 848 849 emitError(loc, "cannot serialize ") 850 << bitwidth << "-bit integer literal: " << rss.str(); 851 return 0; 852 } 853 } 854 855 if (!isSpec) { 856 constIDMap[intAttr] = resultID; 857 } 858 return resultID; 859 } 860 861 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, 862 bool isSpec) { 863 if (!isSpec) { 864 // We can de-duplicate normal constants, but not specialization constants. 865 if (auto id = getConstantID(floatAttr)) { 866 return id; 867 } 868 } 869 870 // Process the type for this float literal 871 uint32_t typeID = 0; 872 if (failed(processType(loc, floatAttr.getType(), typeID))) { 873 return 0; 874 } 875 876 auto resultID = getNextID(); 877 APFloat value = floatAttr.getValue(); 878 APInt intValue = value.bitcastToAPInt(); 879 880 auto opcode = 881 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 882 883 if (&value.getSemantics() == &APFloat::IEEEsingle()) { 884 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); 885 (void)encodeInstructionInto(typesGlobalValues, opcode, 886 {typeID, resultID, word}); 887 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { 888 struct DoubleWord { 889 uint32_t word1; 890 uint32_t word2; 891 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); 892 (void)encodeInstructionInto(typesGlobalValues, opcode, 893 {typeID, resultID, words.word1, words.word2}); 894 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { 895 uint32_t word = 896 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); 897 (void)encodeInstructionInto(typesGlobalValues, opcode, 898 {typeID, resultID, word}); 899 } else { 900 std::string valueStr; 901 llvm::raw_string_ostream rss(valueStr); 902 value.print(rss); 903 904 emitError(loc, "cannot serialize ") 905 << floatAttr.getType() << "-typed float literal: " << rss.str(); 906 return 0; 907 } 908 909 if (!isSpec) { 910 constIDMap[floatAttr] = resultID; 911 } 912 return resultID; 913 } 914 915 //===----------------------------------------------------------------------===// 916 // Control flow 917 //===----------------------------------------------------------------------===// 918 919 uint32_t Serializer::getOrCreateBlockID(Block *block) { 920 if (uint32_t id = getBlockID(block)) 921 return id; 922 return blockIDMap[block] = getNextID(); 923 } 924 925 LogicalResult 926 Serializer::processBlock(Block *block, bool omitLabel, 927 function_ref<void()> actionBeforeTerminator) { 928 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); 929 LLVM_DEBUG(block->print(llvm::dbgs())); 930 LLVM_DEBUG(llvm::dbgs() << '\n'); 931 if (!omitLabel) { 932 uint32_t blockID = getOrCreateBlockID(block); 933 LLVM_DEBUG(llvm::dbgs() 934 << "[block] " << block << " (id = " << blockID << ")\n"); 935 936 // Emit OpLabel for this block. 937 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, 938 {blockID}); 939 } 940 941 // Emit OpPhi instructions for block arguments, if any. 942 if (failed(emitPhiForBlockArguments(block))) 943 return failure(); 944 945 // Process each op in this block except the terminator. 946 for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { 947 if (failed(processOperation(&op))) 948 return failure(); 949 } 950 951 // Process the terminator. 952 if (actionBeforeTerminator) 953 actionBeforeTerminator(); 954 if (failed(processOperation(&block->back()))) 955 return failure(); 956 957 return success(); 958 } 959 960 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { 961 // Nothing to do if this block has no arguments or it's the entry block, which 962 // always has the same arguments as the function signature. 963 if (block->args_empty() || block->isEntryBlock()) 964 return success(); 965 966 // If the block has arguments, we need to create SPIR-V OpPhi instructions. 967 // A SPIR-V OpPhi instruction is of the syntax: 968 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair 969 // So we need to collect all predecessor blocks and the arguments they send 970 // to this block. 971 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors; 972 for (Block *predecessor : block->getPredecessors()) { 973 auto *terminator = predecessor->getTerminator(); 974 // The predecessor here is the immediate one according to MLIR's IR 975 // structure. It does not directly map to the incoming parent block for the 976 // OpPhi instructions at SPIR-V binary level. This is because structured 977 // control flow ops are serialized to multiple SPIR-V blocks. If there is a 978 // spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the 979 // branch op jumping to the OpPhi's block then resides in the previous 980 // structured control flow op's merge block. 981 predecessor = getPhiIncomingBlock(predecessor); 982 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { 983 predecessors.emplace_back(predecessor, branchOp.getOperands()); 984 } else if (auto branchCondOp = 985 dyn_cast<spirv::BranchConditionalOp>(terminator)) { 986 Optional<OperandRange> blockOperands; 987 988 for (auto successorIdx : 989 llvm::seq<unsigned>(0, predecessor->getNumSuccessors())) 990 if (predecessor->getSuccessors()[successorIdx] == block) { 991 blockOperands = branchCondOp.getSuccessorOperands(successorIdx); 992 break; 993 } 994 995 assert(blockOperands && !blockOperands->empty() && 996 "expected non-empty block operand range"); 997 predecessors.emplace_back(predecessor, *blockOperands); 998 } else { 999 return terminator->emitError("unimplemented terminator for Phi creation"); 1000 } 1001 } 1002 1003 // Then create OpPhi instruction for each of the block argument. 1004 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) { 1005 BlockArgument arg = block->getArgument(argIndex); 1006 1007 // Get the type <id> and result <id> for this OpPhi instruction. 1008 uint32_t phiTypeID = 0; 1009 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) 1010 return failure(); 1011 uint32_t phiID = getNextID(); 1012 1013 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' 1014 << arg << " (id = " << phiID << ")\n"); 1015 1016 // Prepare the (value <id>, parent block <id>) pairs. 1017 SmallVector<uint32_t, 8> phiArgs; 1018 phiArgs.push_back(phiTypeID); 1019 phiArgs.push_back(phiID); 1020 1021 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) { 1022 Value value = predecessors[predIndex].second[argIndex]; 1023 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); 1024 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId 1025 << ") value " << value << ' '); 1026 // Each pair is a value <id> ... 1027 uint32_t valueId = getValueID(value); 1028 if (valueId == 0) { 1029 // The op generating this value hasn't been visited yet so we don't have 1030 // an <id> assigned yet. Record this to fix up later. 1031 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); 1032 deferredPhiValues[value].push_back(functionBody.size() + 1 + 1033 phiArgs.size()); 1034 } else { 1035 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); 1036 } 1037 phiArgs.push_back(valueId); 1038 // ... and a parent block <id>. 1039 phiArgs.push_back(predBlockId); 1040 } 1041 1042 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); 1043 valueIDMap[arg] = phiID; 1044 } 1045 1046 return success(); 1047 } 1048 1049 //===----------------------------------------------------------------------===// 1050 // Operation 1051 //===----------------------------------------------------------------------===// 1052 1053 LogicalResult Serializer::encodeExtensionInstruction( 1054 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, 1055 ArrayRef<uint32_t> operands) { 1056 // Check if the extension has been imported. 1057 auto &setID = extendedInstSetIDMap[extensionSetName]; 1058 if (!setID) { 1059 setID = getNextID(); 1060 SmallVector<uint32_t, 16> importOperands; 1061 importOperands.push_back(setID); 1062 if (failed( 1063 spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || 1064 failed(encodeInstructionInto( 1065 extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { 1066 return failure(); 1067 } 1068 } 1069 1070 // The first two operands are the result type <id> and result <id>. The set 1071 // <id> and the opcode need to be insert after this. 1072 if (operands.size() < 2) { 1073 return op->emitError("extended instructions must have a result encoding"); 1074 } 1075 SmallVector<uint32_t, 8> extInstOperands; 1076 extInstOperands.reserve(operands.size() + 2); 1077 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); 1078 extInstOperands.push_back(setID); 1079 extInstOperands.push_back(extensionOpcode); 1080 extInstOperands.append(std::next(operands.begin(), 2), operands.end()); 1081 return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, 1082 extInstOperands); 1083 } 1084 1085 LogicalResult Serializer::processOperation(Operation *opInst) { 1086 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); 1087 1088 // First dispatch the ops that do not directly mirror an instruction from 1089 // the SPIR-V spec. 1090 return TypeSwitch<Operation *, LogicalResult>(opInst) 1091 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) 1092 .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) 1093 .Case([&](spirv::BranchConditionalOp op) { 1094 return processBranchConditionalOp(op); 1095 }) 1096 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) 1097 .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) 1098 .Case([&](spirv::GlobalVariableOp op) { 1099 return processGlobalVariableOp(op); 1100 }) 1101 .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) 1102 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) 1103 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) 1104 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) 1105 .Case([&](spirv::SpecConstantCompositeOp op) { 1106 return processSpecConstantCompositeOp(op); 1107 }) 1108 .Case([&](spirv::SpecConstantOperationOp op) { 1109 return processSpecConstantOperationOp(op); 1110 }) 1111 .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) 1112 .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) 1113 1114 // Then handle all the ops that directly mirror SPIR-V instructions with 1115 // auto-generated methods. 1116 .Default( 1117 [&](Operation *op) { return dispatchToAutogenSerialization(op); }); 1118 } 1119 1120 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, 1121 StringRef extInstSet, 1122 uint32_t opcode) { 1123 SmallVector<uint32_t, 4> operands; 1124 Location loc = op->getLoc(); 1125 1126 uint32_t resultID = 0; 1127 if (op->getNumResults() != 0) { 1128 uint32_t resultTypeID = 0; 1129 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) 1130 return failure(); 1131 operands.push_back(resultTypeID); 1132 1133 resultID = getNextID(); 1134 operands.push_back(resultID); 1135 valueIDMap[op->getResult(0)] = resultID; 1136 }; 1137 1138 for (Value operand : op->getOperands()) 1139 operands.push_back(getValueID(operand)); 1140 1141 (void)emitDebugLine(functionBody, loc); 1142 1143 if (extInstSet.empty()) { 1144 (void)encodeInstructionInto(functionBody, 1145 static_cast<spirv::Opcode>(opcode), operands); 1146 } else { 1147 (void)encodeExtensionInstruction(op, extInstSet, opcode, operands); 1148 } 1149 1150 if (op->getNumResults() != 0) { 1151 for (auto attr : op->getAttrs()) { 1152 if (failed(processDecoration(loc, resultID, attr))) 1153 return failure(); 1154 } 1155 } 1156 1157 return success(); 1158 } 1159 1160 LogicalResult Serializer::emitDecoration(uint32_t target, 1161 spirv::Decoration decoration, 1162 ArrayRef<uint32_t> params) { 1163 uint32_t wordCount = 3 + params.size(); 1164 decorations.push_back( 1165 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); 1166 decorations.push_back(target); 1167 decorations.push_back(static_cast<uint32_t>(decoration)); 1168 decorations.append(params.begin(), params.end()); 1169 return success(); 1170 } 1171 1172 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, 1173 Location loc) { 1174 if (!options.emitDebugInfo) 1175 return success(); 1176 1177 if (lastProcessedWasMergeInst) { 1178 lastProcessedWasMergeInst = false; 1179 return success(); 1180 } 1181 1182 auto fileLoc = loc.dyn_cast<FileLineColLoc>(); 1183 if (fileLoc) 1184 (void)encodeInstructionInto( 1185 binary, spirv::Opcode::OpLine, 1186 {fileID, fileLoc.getLine(), fileLoc.getColumn()}); 1187 return success(); 1188 } 1189 } // namespace spirv 1190 } // namespace mlir 1191