1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Serializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 18 #include "mlir/Support/LogicalResult.h" 19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/ADT/SmallPtrSet.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/ADT/bit.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "spirv-serialization" 28 29 using namespace mlir; 30 31 /// Returns the merge block if the given `op` is a structured control flow op. 32 /// Otherwise returns nullptr. 33 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { 34 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) 35 return selectionOp.getMergeBlock(); 36 if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) 37 return loopOp.getMergeBlock(); 38 return nullptr; 39 } 40 41 /// Given a predecessor `block` for a block with arguments, returns the block 42 /// that should be used as the parent block for SPIR-V OpPhi instructions 43 /// corresponding to the block arguments. 44 static Block *getPhiIncomingBlock(Block *block) { 45 // If the predecessor block in question is the entry block for a 46 // spv.mlir.loop, we jump to this spv.mlir.loop from its enclosing block. 47 if (block->isEntryBlock()) { 48 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { 49 // Then the incoming parent block for OpPhi should be the merge block of 50 // the structured control flow op before this loop. 51 Operation *op = loopOp.getOperation(); 52 while ((op = op->getPrevNode()) != nullptr) 53 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) 54 return incomingBlock; 55 // Or the enclosing block itself if no structured control flow ops 56 // exists before this loop. 57 return loopOp->getBlock(); 58 } 59 } 60 61 // Otherwise, we jump from the given predecessor block. Try to see if there is 62 // a structured control flow op inside it. 63 for (Operation &op : llvm::reverse(block->getOperations())) { 64 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) 65 return incomingBlock; 66 } 67 return block; 68 } 69 70 namespace mlir { 71 namespace spirv { 72 73 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into 74 /// the given `binary` vector. 75 LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, 76 spirv::Opcode op, 77 ArrayRef<uint32_t> operands) { 78 uint32_t wordCount = 1 + operands.size(); 79 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 80 binary.append(operands.begin(), operands.end()); 81 return success(); 82 } 83 84 Serializer::Serializer(spirv::ModuleOp module, bool emitDebugInfo) 85 : module(module), mlirBuilder(module.getContext()), 86 emitDebugInfo(emitDebugInfo) {} 87 88 LogicalResult Serializer::serialize() { 89 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); 90 91 if (failed(module.verify())) 92 return failure(); 93 94 // TODO: handle the other sections 95 processCapability(); 96 processExtension(); 97 processMemoryModel(); 98 processDebugInfo(); 99 100 // Iterate over the module body to serialize it. Assumptions are that there is 101 // only one basic block in the moduleOp 102 for (auto &op : *module.getBody()) { 103 if (failed(processOperation(&op))) { 104 return failure(); 105 } 106 } 107 108 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); 109 return success(); 110 } 111 112 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { 113 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + 114 extensions.size() + extendedSets.size() + 115 memoryModel.size() + entryPoints.size() + 116 executionModes.size() + decorations.size() + 117 typesGlobalValues.size() + functions.size(); 118 119 binary.clear(); 120 binary.reserve(moduleSize); 121 122 spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); 123 binary.append(capabilities.begin(), capabilities.end()); 124 binary.append(extensions.begin(), extensions.end()); 125 binary.append(extendedSets.begin(), extendedSets.end()); 126 binary.append(memoryModel.begin(), memoryModel.end()); 127 binary.append(entryPoints.begin(), entryPoints.end()); 128 binary.append(executionModes.begin(), executionModes.end()); 129 binary.append(debug.begin(), debug.end()); 130 binary.append(names.begin(), names.end()); 131 binary.append(decorations.begin(), decorations.end()); 132 binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); 133 binary.append(functions.begin(), functions.end()); 134 } 135 136 #ifndef NDEBUG 137 void Serializer::printValueIDMap(raw_ostream &os) { 138 os << "\n= Value <id> Map =\n\n"; 139 for (auto valueIDPair : valueIDMap) { 140 Value val = valueIDPair.first; 141 os << " " << val << " " 142 << "id = " << valueIDPair.second << ' '; 143 if (auto *op = val.getDefiningOp()) { 144 os << "from op '" << op->getName() << "'"; 145 } else if (auto arg = val.dyn_cast<BlockArgument>()) { 146 Block *block = arg.getOwner(); 147 os << "from argument of block " << block << ' '; 148 os << " in op '" << block->getParentOp()->getName() << "'"; 149 } 150 os << '\n'; 151 } 152 } 153 #endif 154 155 //===----------------------------------------------------------------------===// 156 // Module structure 157 //===----------------------------------------------------------------------===// 158 159 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { 160 auto funcID = funcIDMap.lookup(fnName); 161 if (!funcID) { 162 funcID = getNextID(); 163 funcIDMap[fnName] = funcID; 164 } 165 return funcID; 166 } 167 168 void Serializer::processCapability() { 169 for (auto cap : module.vce_triple()->getCapabilities()) 170 (void)encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, 171 {static_cast<uint32_t>(cap)}); 172 } 173 174 void Serializer::processDebugInfo() { 175 if (!emitDebugInfo) 176 return; 177 auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>(); 178 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>"; 179 fileID = getNextID(); 180 SmallVector<uint32_t, 16> operands; 181 operands.push_back(fileID); 182 (void)spirv::encodeStringLiteralInto(operands, fileName); 183 (void)encodeInstructionInto(debug, spirv::Opcode::OpString, operands); 184 // TODO: Encode more debug instructions. 185 } 186 187 void Serializer::processExtension() { 188 llvm::SmallVector<uint32_t, 16> extName; 189 for (spirv::Extension ext : module.vce_triple()->getExtensions()) { 190 extName.clear(); 191 (void)spirv::encodeStringLiteralInto(extName, 192 spirv::stringifyExtension(ext)); 193 (void)encodeInstructionInto(extensions, spirv::Opcode::OpExtension, 194 extName); 195 } 196 } 197 198 void Serializer::processMemoryModel() { 199 uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt(); 200 uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt(); 201 202 (void)encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, 203 {am, mm}); 204 } 205 206 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, 207 NamedAttribute attr) { 208 auto attrName = attr.first.strref(); 209 auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); 210 auto decoration = spirv::symbolizeDecoration(decorationName); 211 if (!decoration) { 212 return emitError( 213 loc, "non-argument attributes expected to have snake-case-ified " 214 "decoration name, unhandled attribute with name : ") 215 << attrName; 216 } 217 SmallVector<uint32_t, 1> args; 218 switch (decoration.getValue()) { 219 case spirv::Decoration::Binding: 220 case spirv::Decoration::DescriptorSet: 221 case spirv::Decoration::Location: 222 if (auto intAttr = attr.second.dyn_cast<IntegerAttr>()) { 223 args.push_back(intAttr.getValue().getZExtValue()); 224 break; 225 } 226 return emitError(loc, "expected integer attribute for ") << attrName; 227 case spirv::Decoration::BuiltIn: 228 if (auto strAttr = attr.second.dyn_cast<StringAttr>()) { 229 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); 230 if (enumVal) { 231 args.push_back(static_cast<uint32_t>(enumVal.getValue())); 232 break; 233 } 234 return emitError(loc, "invalid ") 235 << attrName << " attribute " << strAttr.getValue(); 236 } 237 return emitError(loc, "expected string attribute for ") << attrName; 238 case spirv::Decoration::Aliased: 239 case spirv::Decoration::Flat: 240 case spirv::Decoration::NonReadable: 241 case spirv::Decoration::NonWritable: 242 case spirv::Decoration::NoPerspective: 243 case spirv::Decoration::Restrict: 244 case spirv::Decoration::RelaxedPrecision: 245 // For unit attributes, the args list has no values so we do nothing 246 if (auto unitAttr = attr.second.dyn_cast<UnitAttr>()) 247 break; 248 return emitError(loc, "expected unit attribute for ") << attrName; 249 default: 250 return emitError(loc, "unhandled decoration ") << decorationName; 251 } 252 return emitDecoration(resultID, decoration.getValue(), args); 253 } 254 255 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { 256 assert(!name.empty() && "unexpected empty string for OpName"); 257 258 SmallVector<uint32_t, 4> nameOperands; 259 nameOperands.push_back(resultID); 260 if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) { 261 return failure(); 262 } 263 return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); 264 } 265 266 template <> 267 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( 268 Location loc, spirv::ArrayType type, uint32_t resultID) { 269 if (unsigned stride = type.getArrayStride()) { 270 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 271 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 272 } 273 return success(); 274 } 275 276 template <> 277 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( 278 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) { 279 if (unsigned stride = type.getArrayStride()) { 280 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 281 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 282 } 283 return success(); 284 } 285 286 LogicalResult Serializer::processMemberDecoration( 287 uint32_t structID, 288 const spirv::StructType::MemberDecorationInfo &memberDecoration) { 289 SmallVector<uint32_t, 4> args( 290 {structID, memberDecoration.memberIndex, 291 static_cast<uint32_t>(memberDecoration.decoration)}); 292 if (memberDecoration.hasValue) { 293 args.push_back(memberDecoration.decorationValue); 294 } 295 return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, 296 args); 297 } 298 299 //===----------------------------------------------------------------------===// 300 // Type 301 //===----------------------------------------------------------------------===// 302 303 // According to the SPIR-V spec "Validation Rules for Shader Capabilities": 304 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and 305 // PushConstant Storage Classes must be explicitly laid out." 306 bool Serializer::isInterfaceStructPtrType(Type type) const { 307 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 308 switch (ptrType.getStorageClass()) { 309 case spirv::StorageClass::PhysicalStorageBuffer: 310 case spirv::StorageClass::PushConstant: 311 case spirv::StorageClass::StorageBuffer: 312 case spirv::StorageClass::Uniform: 313 return ptrType.getPointeeType().isa<spirv::StructType>(); 314 default: 315 break; 316 } 317 } 318 return false; 319 } 320 321 LogicalResult Serializer::processType(Location loc, Type type, 322 uint32_t &typeID) { 323 // Maintains a set of names for nested identified struct types. This is used 324 // to properly serialize recursive references. 325 SetVector<StringRef> serializationCtx; 326 return processTypeImpl(loc, type, typeID, serializationCtx); 327 } 328 329 LogicalResult 330 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, 331 SetVector<StringRef> &serializationCtx) { 332 typeID = getTypeID(type); 333 if (typeID) { 334 return success(); 335 } 336 typeID = getNextID(); 337 SmallVector<uint32_t, 4> operands; 338 339 operands.push_back(typeID); 340 auto typeEnum = spirv::Opcode::OpTypeVoid; 341 bool deferSerialization = false; 342 343 if ((type.isa<FunctionType>() && 344 succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum, 345 operands))) || 346 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, 347 deferSerialization, serializationCtx))) { 348 if (deferSerialization) 349 return success(); 350 351 typeIDMap[type] = typeID; 352 353 if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) 354 return failure(); 355 356 if (recursiveStructInfos.count(type) != 0) { 357 // This recursive struct type is emitted already, now the OpTypePointer 358 // instructions referring to recursive references are emitted as well. 359 for (auto &ptrInfo : recursiveStructInfos[type]) { 360 // TODO: This might not work if more than 1 recursive reference is 361 // present in the struct. 362 SmallVector<uint32_t, 4> ptrOperands; 363 ptrOperands.push_back(ptrInfo.pointerTypeID); 364 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass)); 365 ptrOperands.push_back(typeIDMap[type]); 366 367 if (failed(encodeInstructionInto( 368 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) 369 return failure(); 370 } 371 372 recursiveStructInfos[type].clear(); 373 } 374 375 return success(); 376 } 377 378 return failure(); 379 } 380 381 LogicalResult Serializer::prepareBasicType( 382 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, 383 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, 384 SetVector<StringRef> &serializationCtx) { 385 deferSerialization = false; 386 387 if (isVoidType(type)) { 388 typeEnum = spirv::Opcode::OpTypeVoid; 389 return success(); 390 } 391 392 if (auto intType = type.dyn_cast<IntegerType>()) { 393 if (intType.getWidth() == 1) { 394 typeEnum = spirv::Opcode::OpTypeBool; 395 return success(); 396 } 397 398 typeEnum = spirv::Opcode::OpTypeInt; 399 operands.push_back(intType.getWidth()); 400 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 401 // to preserve or validate. 402 // 0 indicates unsigned, or no signedness semantics 403 // 1 indicates signed semantics." 404 operands.push_back(intType.isSigned() ? 1 : 0); 405 return success(); 406 } 407 408 if (auto floatType = type.dyn_cast<FloatType>()) { 409 typeEnum = spirv::Opcode::OpTypeFloat; 410 operands.push_back(floatType.getWidth()); 411 return success(); 412 } 413 414 if (auto vectorType = type.dyn_cast<VectorType>()) { 415 uint32_t elementTypeID = 0; 416 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, 417 serializationCtx))) { 418 return failure(); 419 } 420 typeEnum = spirv::Opcode::OpTypeVector; 421 operands.push_back(elementTypeID); 422 operands.push_back(vectorType.getNumElements()); 423 return success(); 424 } 425 426 if (auto imageType = type.dyn_cast<spirv::ImageType>()) { 427 typeEnum = spirv::Opcode::OpTypeImage; 428 uint32_t sampledTypeID = 0; 429 if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) 430 return failure(); 431 432 operands.push_back(sampledTypeID); 433 operands.push_back(static_cast<uint32_t>(imageType.getDim())); 434 operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo())); 435 operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo())); 436 operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo())); 437 operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo())); 438 operands.push_back(static_cast<uint32_t>(imageType.getImageFormat())); 439 return success(); 440 } 441 442 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) { 443 typeEnum = spirv::Opcode::OpTypeArray; 444 uint32_t elementTypeID = 0; 445 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, 446 serializationCtx))) { 447 return failure(); 448 } 449 operands.push_back(elementTypeID); 450 if (auto elementCountID = prepareConstantInt( 451 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { 452 operands.push_back(elementCountID); 453 } 454 return processTypeDecoration(loc, arrayType, resultID); 455 } 456 457 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 458 uint32_t pointeeTypeID = 0; 459 spirv::StructType pointeeStruct = 460 ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 461 462 if (pointeeStruct && pointeeStruct.isIdentified() && 463 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { 464 // A recursive reference to an enclosing struct is found. 465 // 466 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage 467 // class as operands. 468 SmallVector<uint32_t, 2> forwardPtrOperands; 469 forwardPtrOperands.push_back(resultID); 470 forwardPtrOperands.push_back( 471 static_cast<uint32_t>(ptrType.getStorageClass())); 472 473 (void)encodeInstructionInto(typesGlobalValues, 474 spirv::Opcode::OpTypeForwardPointer, 475 forwardPtrOperands); 476 477 // 2. Find the pointee (enclosing) struct. 478 auto structType = spirv::StructType::getIdentified( 479 module.getContext(), pointeeStruct.getIdentifier()); 480 481 if (!structType) 482 return failure(); 483 484 // 3. Mark the OpTypePointer that is supposed to be emitted by this call 485 // as deferred. 486 deferSerialization = true; 487 488 // 4. Record the info needed to emit the deferred OpTypePointer 489 // instruction when the enclosing struct is completely serialized. 490 recursiveStructInfos[structType].push_back( 491 {resultID, ptrType.getStorageClass()}); 492 } else { 493 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, 494 serializationCtx))) 495 return failure(); 496 } 497 498 typeEnum = spirv::Opcode::OpTypePointer; 499 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); 500 operands.push_back(pointeeTypeID); 501 return success(); 502 } 503 504 if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) { 505 uint32_t elementTypeID = 0; 506 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), 507 elementTypeID, serializationCtx))) { 508 return failure(); 509 } 510 typeEnum = spirv::Opcode::OpTypeRuntimeArray; 511 operands.push_back(elementTypeID); 512 return processTypeDecoration(loc, runtimeArrayType, resultID); 513 } 514 515 if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) { 516 typeEnum = spirv::Opcode::OpTypeSampledImage; 517 uint32_t imageTypeID = 0; 518 if (failed( 519 processType(loc, sampledImageType.getImageType(), imageTypeID))) { 520 return failure(); 521 } 522 operands.push_back(imageTypeID); 523 return success(); 524 } 525 526 if (auto structType = type.dyn_cast<spirv::StructType>()) { 527 if (structType.isIdentified()) { 528 (void)processName(resultID, structType.getIdentifier()); 529 serializationCtx.insert(structType.getIdentifier()); 530 } 531 532 bool hasOffset = structType.hasOffset(); 533 for (auto elementIndex : 534 llvm::seq<uint32_t>(0, structType.getNumElements())) { 535 uint32_t elementTypeID = 0; 536 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), 537 elementTypeID, serializationCtx))) { 538 return failure(); 539 } 540 operands.push_back(elementTypeID); 541 if (hasOffset) { 542 // Decorate each struct member with an offset 543 spirv::StructType::MemberDecorationInfo offsetDecoration{ 544 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, 545 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; 546 if (failed(processMemberDecoration(resultID, offsetDecoration))) { 547 return emitError(loc, "cannot decorate ") 548 << elementIndex << "-th member of " << structType 549 << " with its offset"; 550 } 551 } 552 } 553 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 554 structType.getMemberDecorations(memberDecorations); 555 556 for (auto &memberDecoration : memberDecorations) { 557 if (failed(processMemberDecoration(resultID, memberDecoration))) { 558 return emitError(loc, "cannot decorate ") 559 << static_cast<uint32_t>(memberDecoration.memberIndex) 560 << "-th member of " << structType << " with " 561 << stringifyDecoration(memberDecoration.decoration); 562 } 563 } 564 565 typeEnum = spirv::Opcode::OpTypeStruct; 566 567 if (structType.isIdentified()) 568 serializationCtx.remove(structType.getIdentifier()); 569 570 return success(); 571 } 572 573 if (auto cooperativeMatrixType = 574 type.dyn_cast<spirv::CooperativeMatrixNVType>()) { 575 uint32_t elementTypeID = 0; 576 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), 577 elementTypeID, serializationCtx))) { 578 return failure(); 579 } 580 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; 581 auto getConstantOp = [&](uint32_t id) { 582 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); 583 return prepareConstantInt(loc, attr); 584 }; 585 operands.push_back(elementTypeID); 586 operands.push_back( 587 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()))); 588 operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); 589 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); 590 return success(); 591 } 592 593 if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) { 594 uint32_t elementTypeID = 0; 595 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, 596 serializationCtx))) { 597 return failure(); 598 } 599 typeEnum = spirv::Opcode::OpTypeMatrix; 600 operands.push_back(elementTypeID); 601 operands.push_back(matrixType.getNumColumns()); 602 return success(); 603 } 604 605 // TODO: Handle other types. 606 return emitError(loc, "unhandled type in serialization: ") << type; 607 } 608 609 LogicalResult 610 Serializer::prepareFunctionType(Location loc, FunctionType type, 611 spirv::Opcode &typeEnum, 612 SmallVectorImpl<uint32_t> &operands) { 613 typeEnum = spirv::Opcode::OpTypeFunction; 614 assert(type.getNumResults() <= 1 && 615 "serialization supports only a single return value"); 616 uint32_t resultID = 0; 617 if (failed(processType( 618 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), 619 resultID))) { 620 return failure(); 621 } 622 operands.push_back(resultID); 623 for (auto &res : type.getInputs()) { 624 uint32_t argTypeID = 0; 625 if (failed(processType(loc, res, argTypeID))) { 626 return failure(); 627 } 628 operands.push_back(argTypeID); 629 } 630 return success(); 631 } 632 633 //===----------------------------------------------------------------------===// 634 // Constant 635 //===----------------------------------------------------------------------===// 636 637 uint32_t Serializer::prepareConstant(Location loc, Type constType, 638 Attribute valueAttr) { 639 if (auto id = prepareConstantScalar(loc, valueAttr)) { 640 return id; 641 } 642 643 // This is a composite literal. We need to handle each component separately 644 // and then emit an OpConstantComposite for the whole. 645 646 if (auto id = getConstantID(valueAttr)) { 647 return id; 648 } 649 650 uint32_t typeID = 0; 651 if (failed(processType(loc, constType, typeID))) { 652 return 0; 653 } 654 655 uint32_t resultID = 0; 656 if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) { 657 int rank = attr.getType().dyn_cast<ShapedType>().getRank(); 658 SmallVector<uint64_t, 4> index(rank); 659 resultID = prepareDenseElementsConstant(loc, constType, attr, 660 /*dim=*/0, index); 661 } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) { 662 resultID = prepareArrayConstant(loc, constType, arrayAttr); 663 } 664 665 if (resultID == 0) { 666 emitError(loc, "cannot serialize attribute: ") << valueAttr; 667 return 0; 668 } 669 670 constIDMap[valueAttr] = resultID; 671 return resultID; 672 } 673 674 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, 675 ArrayAttr attr) { 676 uint32_t typeID = 0; 677 if (failed(processType(loc, constType, typeID))) { 678 return 0; 679 } 680 681 uint32_t resultID = getNextID(); 682 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 683 operands.reserve(attr.size() + 2); 684 auto elementType = constType.cast<spirv::ArrayType>().getElementType(); 685 for (Attribute elementAttr : attr) { 686 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { 687 operands.push_back(elementID); 688 } else { 689 return 0; 690 } 691 } 692 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 693 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 694 695 return resultID; 696 } 697 698 // TODO: Turn the below function into iterative function, instead of 699 // recursive function. 700 uint32_t 701 Serializer::prepareDenseElementsConstant(Location loc, Type constType, 702 DenseElementsAttr valueAttr, int dim, 703 MutableArrayRef<uint64_t> index) { 704 auto shapedType = valueAttr.getType().dyn_cast<ShapedType>(); 705 assert(dim <= shapedType.getRank()); 706 if (shapedType.getRank() == dim) { 707 if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) { 708 return attr.getType().getElementType().isInteger(1) 709 ? prepareConstantBool(loc, attr.getValue<BoolAttr>(index)) 710 : prepareConstantInt(loc, attr.getValue<IntegerAttr>(index)); 711 } 712 if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) { 713 return prepareConstantFp(loc, attr.getValue<FloatAttr>(index)); 714 } 715 return 0; 716 } 717 718 uint32_t typeID = 0; 719 if (failed(processType(loc, constType, typeID))) { 720 return 0; 721 } 722 723 uint32_t resultID = getNextID(); 724 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 725 operands.reserve(shapedType.getDimSize(dim) + 2); 726 auto elementType = constType.cast<spirv::CompositeType>().getElementType(0); 727 for (int i = 0; i < shapedType.getDimSize(dim); ++i) { 728 index[dim] = i; 729 if (auto elementID = prepareDenseElementsConstant( 730 loc, elementType, valueAttr, dim + 1, index)) { 731 operands.push_back(elementID); 732 } else { 733 return 0; 734 } 735 } 736 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 737 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 738 739 return resultID; 740 } 741 742 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, 743 bool isSpec) { 744 if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) { 745 return prepareConstantFp(loc, floatAttr, isSpec); 746 } 747 if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) { 748 return prepareConstantBool(loc, boolAttr, isSpec); 749 } 750 if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) { 751 return prepareConstantInt(loc, intAttr, isSpec); 752 } 753 754 return 0; 755 } 756 757 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, 758 bool isSpec) { 759 if (!isSpec) { 760 // We can de-duplicate normal constants, but not specialization constants. 761 if (auto id = getConstantID(boolAttr)) { 762 return id; 763 } 764 } 765 766 // Process the type for this bool literal 767 uint32_t typeID = 0; 768 if (failed(processType(loc, boolAttr.getType(), typeID))) { 769 return 0; 770 } 771 772 auto resultID = getNextID(); 773 auto opcode = boolAttr.getValue() 774 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue 775 : spirv::Opcode::OpConstantTrue) 776 : (isSpec ? spirv::Opcode::OpSpecConstantFalse 777 : spirv::Opcode::OpConstantFalse); 778 (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); 779 780 if (!isSpec) { 781 constIDMap[boolAttr] = resultID; 782 } 783 return resultID; 784 } 785 786 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, 787 bool isSpec) { 788 if (!isSpec) { 789 // We can de-duplicate normal constants, but not specialization constants. 790 if (auto id = getConstantID(intAttr)) { 791 return id; 792 } 793 } 794 795 // Process the type for this integer literal 796 uint32_t typeID = 0; 797 if (failed(processType(loc, intAttr.getType(), typeID))) { 798 return 0; 799 } 800 801 auto resultID = getNextID(); 802 APInt value = intAttr.getValue(); 803 unsigned bitwidth = value.getBitWidth(); 804 bool isSigned = value.isSignedIntN(bitwidth); 805 806 auto opcode = 807 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 808 809 switch (bitwidth) { 810 // According to SPIR-V spec, "When the type's bit width is less than 811 // 32-bits, the literal's value appears in the low-order bits of the word, 812 // and the high-order bits must be 0 for a floating-point type, or 0 for an 813 // integer type with Signedness of 0, or sign extended when Signedness 814 // is 1." 815 case 32: 816 case 16: 817 case 8: { 818 uint32_t word = 0; 819 if (isSigned) { 820 word = static_cast<int32_t>(value.getSExtValue()); 821 } else { 822 word = static_cast<uint32_t>(value.getZExtValue()); 823 } 824 (void)encodeInstructionInto(typesGlobalValues, opcode, 825 {typeID, resultID, word}); 826 } break; 827 // According to SPIR-V spec: "When the type's bit width is larger than one 828 // word, the literal’s low-order words appear first." 829 case 64: { 830 struct DoubleWord { 831 uint32_t word1; 832 uint32_t word2; 833 } words; 834 if (isSigned) { 835 words = llvm::bit_cast<DoubleWord>(value.getSExtValue()); 836 } else { 837 words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); 838 } 839 (void)encodeInstructionInto(typesGlobalValues, opcode, 840 {typeID, resultID, words.word1, words.word2}); 841 } break; 842 default: { 843 std::string valueStr; 844 llvm::raw_string_ostream rss(valueStr); 845 value.print(rss, /*isSigned=*/false); 846 847 emitError(loc, "cannot serialize ") 848 << bitwidth << "-bit integer literal: " << rss.str(); 849 return 0; 850 } 851 } 852 853 if (!isSpec) { 854 constIDMap[intAttr] = resultID; 855 } 856 return resultID; 857 } 858 859 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, 860 bool isSpec) { 861 if (!isSpec) { 862 // We can de-duplicate normal constants, but not specialization constants. 863 if (auto id = getConstantID(floatAttr)) { 864 return id; 865 } 866 } 867 868 // Process the type for this float literal 869 uint32_t typeID = 0; 870 if (failed(processType(loc, floatAttr.getType(), typeID))) { 871 return 0; 872 } 873 874 auto resultID = getNextID(); 875 APFloat value = floatAttr.getValue(); 876 APInt intValue = value.bitcastToAPInt(); 877 878 auto opcode = 879 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 880 881 if (&value.getSemantics() == &APFloat::IEEEsingle()) { 882 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); 883 (void)encodeInstructionInto(typesGlobalValues, opcode, 884 {typeID, resultID, word}); 885 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { 886 struct DoubleWord { 887 uint32_t word1; 888 uint32_t word2; 889 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); 890 (void)encodeInstructionInto(typesGlobalValues, opcode, 891 {typeID, resultID, words.word1, words.word2}); 892 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { 893 uint32_t word = 894 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); 895 (void)encodeInstructionInto(typesGlobalValues, opcode, 896 {typeID, resultID, word}); 897 } else { 898 std::string valueStr; 899 llvm::raw_string_ostream rss(valueStr); 900 value.print(rss); 901 902 emitError(loc, "cannot serialize ") 903 << floatAttr.getType() << "-typed float literal: " << rss.str(); 904 return 0; 905 } 906 907 if (!isSpec) { 908 constIDMap[floatAttr] = resultID; 909 } 910 return resultID; 911 } 912 913 //===----------------------------------------------------------------------===// 914 // Control flow 915 //===----------------------------------------------------------------------===// 916 917 uint32_t Serializer::getOrCreateBlockID(Block *block) { 918 if (uint32_t id = getBlockID(block)) 919 return id; 920 return blockIDMap[block] = getNextID(); 921 } 922 923 LogicalResult 924 Serializer::processBlock(Block *block, bool omitLabel, 925 function_ref<void()> actionBeforeTerminator) { 926 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); 927 LLVM_DEBUG(block->print(llvm::dbgs())); 928 LLVM_DEBUG(llvm::dbgs() << '\n'); 929 if (!omitLabel) { 930 uint32_t blockID = getOrCreateBlockID(block); 931 LLVM_DEBUG(llvm::dbgs() 932 << "[block] " << block << " (id = " << blockID << ")\n"); 933 934 // Emit OpLabel for this block. 935 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, 936 {blockID}); 937 } 938 939 // Emit OpPhi instructions for block arguments, if any. 940 if (failed(emitPhiForBlockArguments(block))) 941 return failure(); 942 943 // Process each op in this block except the terminator. 944 for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { 945 if (failed(processOperation(&op))) 946 return failure(); 947 } 948 949 // Process the terminator. 950 if (actionBeforeTerminator) 951 actionBeforeTerminator(); 952 if (failed(processOperation(&block->back()))) 953 return failure(); 954 955 return success(); 956 } 957 958 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { 959 // Nothing to do if this block has no arguments or it's the entry block, which 960 // always has the same arguments as the function signature. 961 if (block->args_empty() || block->isEntryBlock()) 962 return success(); 963 964 // If the block has arguments, we need to create SPIR-V OpPhi instructions. 965 // A SPIR-V OpPhi instruction is of the syntax: 966 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair 967 // So we need to collect all predecessor blocks and the arguments they send 968 // to this block. 969 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors; 970 for (Block *predecessor : block->getPredecessors()) { 971 auto *terminator = predecessor->getTerminator(); 972 // The predecessor here is the immediate one according to MLIR's IR 973 // structure. It does not directly map to the incoming parent block for the 974 // OpPhi instructions at SPIR-V binary level. This is because structured 975 // control flow ops are serialized to multiple SPIR-V blocks. If there is a 976 // spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the 977 // branch op jumping to the OpPhi's block then resides in the previous 978 // structured control flow op's merge block. 979 predecessor = getPhiIncomingBlock(predecessor); 980 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { 981 predecessors.emplace_back(predecessor, branchOp.getOperands()); 982 } else if (auto branchCondOp = 983 dyn_cast<spirv::BranchConditionalOp>(terminator)) { 984 Optional<OperandRange> blockOperands; 985 986 for (auto successorIdx : 987 llvm::seq<unsigned>(0, predecessor->getNumSuccessors())) 988 if (predecessor->getSuccessors()[successorIdx] == block) { 989 blockOperands = branchCondOp.getSuccessorOperands(successorIdx); 990 break; 991 } 992 993 assert(blockOperands && !blockOperands->empty() && 994 "expected non-empty block operand range"); 995 predecessors.emplace_back(predecessor, *blockOperands); 996 } else { 997 return terminator->emitError("unimplemented terminator for Phi creation"); 998 } 999 } 1000 1001 // Then create OpPhi instruction for each of the block argument. 1002 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) { 1003 BlockArgument arg = block->getArgument(argIndex); 1004 1005 // Get the type <id> and result <id> for this OpPhi instruction. 1006 uint32_t phiTypeID = 0; 1007 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) 1008 return failure(); 1009 uint32_t phiID = getNextID(); 1010 1011 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' 1012 << arg << " (id = " << phiID << ")\n"); 1013 1014 // Prepare the (value <id>, parent block <id>) pairs. 1015 SmallVector<uint32_t, 8> phiArgs; 1016 phiArgs.push_back(phiTypeID); 1017 phiArgs.push_back(phiID); 1018 1019 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) { 1020 Value value = predecessors[predIndex].second[argIndex]; 1021 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); 1022 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId 1023 << ") value " << value << ' '); 1024 // Each pair is a value <id> ... 1025 uint32_t valueId = getValueID(value); 1026 if (valueId == 0) { 1027 // The op generating this value hasn't been visited yet so we don't have 1028 // an <id> assigned yet. Record this to fix up later. 1029 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); 1030 deferredPhiValues[value].push_back(functionBody.size() + 1 + 1031 phiArgs.size()); 1032 } else { 1033 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); 1034 } 1035 phiArgs.push_back(valueId); 1036 // ... and a parent block <id>. 1037 phiArgs.push_back(predBlockId); 1038 } 1039 1040 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); 1041 valueIDMap[arg] = phiID; 1042 } 1043 1044 return success(); 1045 } 1046 1047 //===----------------------------------------------------------------------===// 1048 // Operation 1049 //===----------------------------------------------------------------------===// 1050 1051 LogicalResult Serializer::encodeExtensionInstruction( 1052 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, 1053 ArrayRef<uint32_t> operands) { 1054 // Check if the extension has been imported. 1055 auto &setID = extendedInstSetIDMap[extensionSetName]; 1056 if (!setID) { 1057 setID = getNextID(); 1058 SmallVector<uint32_t, 16> importOperands; 1059 importOperands.push_back(setID); 1060 if (failed( 1061 spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || 1062 failed(encodeInstructionInto( 1063 extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { 1064 return failure(); 1065 } 1066 } 1067 1068 // The first two operands are the result type <id> and result <id>. The set 1069 // <id> and the opcode need to be insert after this. 1070 if (operands.size() < 2) { 1071 return op->emitError("extended instructions must have a result encoding"); 1072 } 1073 SmallVector<uint32_t, 8> extInstOperands; 1074 extInstOperands.reserve(operands.size() + 2); 1075 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); 1076 extInstOperands.push_back(setID); 1077 extInstOperands.push_back(extensionOpcode); 1078 extInstOperands.append(std::next(operands.begin(), 2), operands.end()); 1079 return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, 1080 extInstOperands); 1081 } 1082 1083 LogicalResult Serializer::processOperation(Operation *opInst) { 1084 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); 1085 1086 // First dispatch the ops that do not directly mirror an instruction from 1087 // the SPIR-V spec. 1088 return TypeSwitch<Operation *, LogicalResult>(opInst) 1089 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) 1090 .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) 1091 .Case([&](spirv::BranchConditionalOp op) { 1092 return processBranchConditionalOp(op); 1093 }) 1094 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) 1095 .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) 1096 .Case([&](spirv::GlobalVariableOp op) { 1097 return processGlobalVariableOp(op); 1098 }) 1099 .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) 1100 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) 1101 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) 1102 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) 1103 .Case([&](spirv::SpecConstantCompositeOp op) { 1104 return processSpecConstantCompositeOp(op); 1105 }) 1106 .Case([&](spirv::SpecConstantOperationOp op) { 1107 return processSpecConstantOperationOp(op); 1108 }) 1109 .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) 1110 .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) 1111 1112 // Then handle all the ops that directly mirror SPIR-V instructions with 1113 // auto-generated methods. 1114 .Default( 1115 [&](Operation *op) { return dispatchToAutogenSerialization(op); }); 1116 } 1117 1118 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, 1119 StringRef extInstSet, 1120 uint32_t opcode) { 1121 SmallVector<uint32_t, 4> operands; 1122 Location loc = op->getLoc(); 1123 1124 uint32_t resultID = 0; 1125 if (op->getNumResults() != 0) { 1126 uint32_t resultTypeID = 0; 1127 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) 1128 return failure(); 1129 operands.push_back(resultTypeID); 1130 1131 resultID = getNextID(); 1132 operands.push_back(resultID); 1133 valueIDMap[op->getResult(0)] = resultID; 1134 }; 1135 1136 for (Value operand : op->getOperands()) 1137 operands.push_back(getValueID(operand)); 1138 1139 (void)emitDebugLine(functionBody, loc); 1140 1141 if (extInstSet.empty()) { 1142 (void)encodeInstructionInto(functionBody, 1143 static_cast<spirv::Opcode>(opcode), operands); 1144 } else { 1145 (void)encodeExtensionInstruction(op, extInstSet, opcode, operands); 1146 } 1147 1148 if (op->getNumResults() != 0) { 1149 for (auto attr : op->getAttrs()) { 1150 if (failed(processDecoration(loc, resultID, attr))) 1151 return failure(); 1152 } 1153 } 1154 1155 return success(); 1156 } 1157 1158 LogicalResult Serializer::emitDecoration(uint32_t target, 1159 spirv::Decoration decoration, 1160 ArrayRef<uint32_t> params) { 1161 uint32_t wordCount = 3 + params.size(); 1162 decorations.push_back( 1163 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); 1164 decorations.push_back(target); 1165 decorations.push_back(static_cast<uint32_t>(decoration)); 1166 decorations.append(params.begin(), params.end()); 1167 return success(); 1168 } 1169 1170 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, 1171 Location loc) { 1172 if (!emitDebugInfo) 1173 return success(); 1174 1175 if (lastProcessedWasMergeInst) { 1176 lastProcessedWasMergeInst = false; 1177 return success(); 1178 } 1179 1180 auto fileLoc = loc.dyn_cast<FileLineColLoc>(); 1181 if (fileLoc) 1182 (void)encodeInstructionInto( 1183 binary, spirv::Opcode::OpLine, 1184 {fileID, fileLoc.getLine(), fileLoc.getColumn()}); 1185 return success(); 1186 } 1187 } // namespace spirv 1188 } // namespace mlir 1189