1 //===- Serializer.cpp - MLIR SPIR-V Serializer ----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the MLIR SPIR-V module to SPIR-V binary serializer. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Serializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 18 #include "mlir/Support/LogicalResult.h" 19 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 20 #include "llvm/ADT/Sequence.h" 21 #include "llvm/ADT/SmallPtrSet.h" 22 #include "llvm/ADT/StringExtras.h" 23 #include "llvm/ADT/TypeSwitch.h" 24 #include "llvm/ADT/bit.h" 25 #include "llvm/Support/Debug.h" 26 27 #define DEBUG_TYPE "spirv-serialization" 28 29 using namespace mlir; 30 31 /// Returns the merge block if the given `op` is a structured control flow op. 32 /// Otherwise returns nullptr. 33 static Block *getStructuredControlFlowOpMergeBlock(Operation *op) { 34 if (auto selectionOp = dyn_cast<spirv::SelectionOp>(op)) 35 return selectionOp.getMergeBlock(); 36 if (auto loopOp = dyn_cast<spirv::LoopOp>(op)) 37 return loopOp.getMergeBlock(); 38 return nullptr; 39 } 40 41 /// Given a predecessor `block` for a block with arguments, returns the block 42 /// that should be used as the parent block for SPIR-V OpPhi instructions 43 /// corresponding to the block arguments. 44 static Block *getPhiIncomingBlock(Block *block) { 45 // If the predecessor block in question is the entry block for a 46 // spv.mlir.loop, we jump to this spv.mlir.loop from its enclosing block. 47 if (block->isEntryBlock()) { 48 if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) { 49 // Then the incoming parent block for OpPhi should be the merge block of 50 // the structured control flow op before this loop. 51 Operation *op = loopOp.getOperation(); 52 while ((op = op->getPrevNode()) != nullptr) 53 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(op)) 54 return incomingBlock; 55 // Or the enclosing block itself if no structured control flow ops 56 // exists before this loop. 57 return loopOp->getBlock(); 58 } 59 } 60 61 // Otherwise, we jump from the given predecessor block. Try to see if there is 62 // a structured control flow op inside it. 63 for (Operation &op : llvm::reverse(block->getOperations())) { 64 if (Block *incomingBlock = getStructuredControlFlowOpMergeBlock(&op)) 65 return incomingBlock; 66 } 67 return block; 68 } 69 70 namespace mlir { 71 namespace spirv { 72 73 /// Encodes an SPIR-V instruction with the given `opcode` and `operands` into 74 /// the given `binary` vector. 75 LogicalResult encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, 76 spirv::Opcode op, 77 ArrayRef<uint32_t> operands) { 78 uint32_t wordCount = 1 + operands.size(); 79 binary.push_back(spirv::getPrefixedOpcode(wordCount, op)); 80 binary.append(operands.begin(), operands.end()); 81 return success(); 82 } 83 84 Serializer::Serializer(spirv::ModuleOp module, 85 const SerializationOptions &options) 86 : module(module), mlirBuilder(module.getContext()), options(options) {} 87 88 LogicalResult Serializer::serialize() { 89 LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n"); 90 91 if (failed(module.verify())) 92 return failure(); 93 94 // TODO: handle the other sections 95 processCapability(); 96 processExtension(); 97 processMemoryModel(); 98 processDebugInfo(); 99 100 // Iterate over the module body to serialize it. Assumptions are that there is 101 // only one basic block in the moduleOp 102 for (auto &op : *module.getBody()) { 103 if (failed(processOperation(&op))) { 104 return failure(); 105 } 106 } 107 108 LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n"); 109 return success(); 110 } 111 112 void Serializer::collect(SmallVectorImpl<uint32_t> &binary) { 113 auto moduleSize = spirv::kHeaderWordCount + capabilities.size() + 114 extensions.size() + extendedSets.size() + 115 memoryModel.size() + entryPoints.size() + 116 executionModes.size() + decorations.size() + 117 typesGlobalValues.size() + functions.size(); 118 119 binary.clear(); 120 binary.reserve(moduleSize); 121 122 spirv::appendModuleHeader(binary, module.vce_triple()->getVersion(), nextID); 123 binary.append(capabilities.begin(), capabilities.end()); 124 binary.append(extensions.begin(), extensions.end()); 125 binary.append(extendedSets.begin(), extendedSets.end()); 126 binary.append(memoryModel.begin(), memoryModel.end()); 127 binary.append(entryPoints.begin(), entryPoints.end()); 128 binary.append(executionModes.begin(), executionModes.end()); 129 binary.append(debug.begin(), debug.end()); 130 binary.append(names.begin(), names.end()); 131 binary.append(decorations.begin(), decorations.end()); 132 binary.append(typesGlobalValues.begin(), typesGlobalValues.end()); 133 binary.append(functions.begin(), functions.end()); 134 } 135 136 #ifndef NDEBUG 137 void Serializer::printValueIDMap(raw_ostream &os) { 138 os << "\n= Value <id> Map =\n\n"; 139 for (auto valueIDPair : valueIDMap) { 140 Value val = valueIDPair.first; 141 os << " " << val << " " 142 << "id = " << valueIDPair.second << ' '; 143 if (auto *op = val.getDefiningOp()) { 144 os << "from op '" << op->getName() << "'"; 145 } else if (auto arg = val.dyn_cast<BlockArgument>()) { 146 Block *block = arg.getOwner(); 147 os << "from argument of block " << block << ' '; 148 os << " in op '" << block->getParentOp()->getName() << "'"; 149 } 150 os << '\n'; 151 } 152 } 153 #endif 154 155 //===----------------------------------------------------------------------===// 156 // Module structure 157 //===----------------------------------------------------------------------===// 158 159 uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) { 160 auto funcID = funcIDMap.lookup(fnName); 161 if (!funcID) { 162 funcID = getNextID(); 163 funcIDMap[fnName] = funcID; 164 } 165 return funcID; 166 } 167 168 void Serializer::processCapability() { 169 for (auto cap : module.vce_triple()->getCapabilities()) 170 (void)encodeInstructionInto(capabilities, spirv::Opcode::OpCapability, 171 {static_cast<uint32_t>(cap)}); 172 } 173 174 void Serializer::processDebugInfo() { 175 if (!options.emitDebugInfo) 176 return; 177 auto fileLoc = module.getLoc().dyn_cast<FileLineColLoc>(); 178 auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>"; 179 fileID = getNextID(); 180 SmallVector<uint32_t, 16> operands; 181 operands.push_back(fileID); 182 (void)spirv::encodeStringLiteralInto(operands, fileName); 183 (void)encodeInstructionInto(debug, spirv::Opcode::OpString, operands); 184 // TODO: Encode more debug instructions. 185 } 186 187 void Serializer::processExtension() { 188 llvm::SmallVector<uint32_t, 16> extName; 189 for (spirv::Extension ext : module.vce_triple()->getExtensions()) { 190 extName.clear(); 191 (void)spirv::encodeStringLiteralInto(extName, 192 spirv::stringifyExtension(ext)); 193 (void)encodeInstructionInto(extensions, spirv::Opcode::OpExtension, 194 extName); 195 } 196 } 197 198 void Serializer::processMemoryModel() { 199 uint32_t mm = module->getAttrOfType<IntegerAttr>("memory_model").getInt(); 200 uint32_t am = module->getAttrOfType<IntegerAttr>("addressing_model").getInt(); 201 202 (void)encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel, 203 {am, mm}); 204 } 205 206 LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID, 207 NamedAttribute attr) { 208 auto attrName = attr.getName().strref(); 209 auto decorationName = llvm::convertToCamelFromSnakeCase(attrName, true); 210 auto decoration = spirv::symbolizeDecoration(decorationName); 211 if (!decoration) { 212 return emitError( 213 loc, "non-argument attributes expected to have snake-case-ified " 214 "decoration name, unhandled attribute with name : ") 215 << attrName; 216 } 217 SmallVector<uint32_t, 1> args; 218 switch (decoration.getValue()) { 219 case spirv::Decoration::Binding: 220 case spirv::Decoration::DescriptorSet: 221 case spirv::Decoration::Location: 222 if (auto intAttr = attr.getValue().dyn_cast<IntegerAttr>()) { 223 args.push_back(intAttr.getValue().getZExtValue()); 224 break; 225 } 226 return emitError(loc, "expected integer attribute for ") << attrName; 227 case spirv::Decoration::BuiltIn: 228 if (auto strAttr = attr.getValue().dyn_cast<StringAttr>()) { 229 auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue()); 230 if (enumVal) { 231 args.push_back(static_cast<uint32_t>(enumVal.getValue())); 232 break; 233 } 234 return emitError(loc, "invalid ") 235 << attrName << " attribute " << strAttr.getValue(); 236 } 237 return emitError(loc, "expected string attribute for ") << attrName; 238 case spirv::Decoration::Aliased: 239 case spirv::Decoration::Flat: 240 case spirv::Decoration::NonReadable: 241 case spirv::Decoration::NonWritable: 242 case spirv::Decoration::NoPerspective: 243 case spirv::Decoration::Restrict: 244 case spirv::Decoration::RelaxedPrecision: 245 // For unit attributes, the args list has no values so we do nothing 246 if (auto unitAttr = attr.getValue().dyn_cast<UnitAttr>()) 247 break; 248 return emitError(loc, "expected unit attribute for ") << attrName; 249 default: 250 return emitError(loc, "unhandled decoration ") << decorationName; 251 } 252 return emitDecoration(resultID, decoration.getValue(), args); 253 } 254 255 LogicalResult Serializer::processName(uint32_t resultID, StringRef name) { 256 assert(!name.empty() && "unexpected empty string for OpName"); 257 if (!options.emitSymbolName) 258 return success(); 259 260 SmallVector<uint32_t, 4> nameOperands; 261 nameOperands.push_back(resultID); 262 if (failed(spirv::encodeStringLiteralInto(nameOperands, name))) 263 return failure(); 264 return encodeInstructionInto(names, spirv::Opcode::OpName, nameOperands); 265 } 266 267 template <> 268 LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>( 269 Location loc, spirv::ArrayType type, uint32_t resultID) { 270 if (unsigned stride = type.getArrayStride()) { 271 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 272 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 273 } 274 return success(); 275 } 276 277 template <> 278 LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>( 279 Location loc, spirv::RuntimeArrayType type, uint32_t resultID) { 280 if (unsigned stride = type.getArrayStride()) { 281 // OpDecorate %arrayTypeSSA ArrayStride strideLiteral 282 return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride}); 283 } 284 return success(); 285 } 286 287 LogicalResult Serializer::processMemberDecoration( 288 uint32_t structID, 289 const spirv::StructType::MemberDecorationInfo &memberDecoration) { 290 SmallVector<uint32_t, 4> args( 291 {structID, memberDecoration.memberIndex, 292 static_cast<uint32_t>(memberDecoration.decoration)}); 293 if (memberDecoration.hasValue) { 294 args.push_back(memberDecoration.decorationValue); 295 } 296 return encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate, 297 args); 298 } 299 300 //===----------------------------------------------------------------------===// 301 // Type 302 //===----------------------------------------------------------------------===// 303 304 // According to the SPIR-V spec "Validation Rules for Shader Capabilities": 305 // "Composite objects in the StorageBuffer, PhysicalStorageBuffer, Uniform, and 306 // PushConstant Storage Classes must be explicitly laid out." 307 bool Serializer::isInterfaceStructPtrType(Type type) const { 308 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 309 switch (ptrType.getStorageClass()) { 310 case spirv::StorageClass::PhysicalStorageBuffer: 311 case spirv::StorageClass::PushConstant: 312 case spirv::StorageClass::StorageBuffer: 313 case spirv::StorageClass::Uniform: 314 return ptrType.getPointeeType().isa<spirv::StructType>(); 315 default: 316 break; 317 } 318 } 319 return false; 320 } 321 322 LogicalResult Serializer::processType(Location loc, Type type, 323 uint32_t &typeID) { 324 // Maintains a set of names for nested identified struct types. This is used 325 // to properly serialize recursive references. 326 SetVector<StringRef> serializationCtx; 327 return processTypeImpl(loc, type, typeID, serializationCtx); 328 } 329 330 LogicalResult 331 Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID, 332 SetVector<StringRef> &serializationCtx) { 333 typeID = getTypeID(type); 334 if (typeID) 335 return success(); 336 337 typeID = getNextID(); 338 SmallVector<uint32_t, 4> operands; 339 340 operands.push_back(typeID); 341 auto typeEnum = spirv::Opcode::OpTypeVoid; 342 bool deferSerialization = false; 343 344 if ((type.isa<FunctionType>() && 345 succeeded(prepareFunctionType(loc, type.cast<FunctionType>(), typeEnum, 346 operands))) || 347 succeeded(prepareBasicType(loc, type, typeID, typeEnum, operands, 348 deferSerialization, serializationCtx))) { 349 if (deferSerialization) 350 return success(); 351 352 typeIDMap[type] = typeID; 353 354 if (failed(encodeInstructionInto(typesGlobalValues, typeEnum, operands))) 355 return failure(); 356 357 if (recursiveStructInfos.count(type) != 0) { 358 // This recursive struct type is emitted already, now the OpTypePointer 359 // instructions referring to recursive references are emitted as well. 360 for (auto &ptrInfo : recursiveStructInfos[type]) { 361 // TODO: This might not work if more than 1 recursive reference is 362 // present in the struct. 363 SmallVector<uint32_t, 4> ptrOperands; 364 ptrOperands.push_back(ptrInfo.pointerTypeID); 365 ptrOperands.push_back(static_cast<uint32_t>(ptrInfo.storageClass)); 366 ptrOperands.push_back(typeIDMap[type]); 367 368 if (failed(encodeInstructionInto( 369 typesGlobalValues, spirv::Opcode::OpTypePointer, ptrOperands))) 370 return failure(); 371 } 372 373 recursiveStructInfos[type].clear(); 374 } 375 376 return success(); 377 } 378 379 return failure(); 380 } 381 382 LogicalResult Serializer::prepareBasicType( 383 Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum, 384 SmallVectorImpl<uint32_t> &operands, bool &deferSerialization, 385 SetVector<StringRef> &serializationCtx) { 386 deferSerialization = false; 387 388 if (isVoidType(type)) { 389 typeEnum = spirv::Opcode::OpTypeVoid; 390 return success(); 391 } 392 393 if (auto intType = type.dyn_cast<IntegerType>()) { 394 if (intType.getWidth() == 1) { 395 typeEnum = spirv::Opcode::OpTypeBool; 396 return success(); 397 } 398 399 typeEnum = spirv::Opcode::OpTypeInt; 400 operands.push_back(intType.getWidth()); 401 // SPIR-V OpTypeInt "Signedness specifies whether there are signed semantics 402 // to preserve or validate. 403 // 0 indicates unsigned, or no signedness semantics 404 // 1 indicates signed semantics." 405 operands.push_back(intType.isSigned() ? 1 : 0); 406 return success(); 407 } 408 409 if (auto floatType = type.dyn_cast<FloatType>()) { 410 typeEnum = spirv::Opcode::OpTypeFloat; 411 operands.push_back(floatType.getWidth()); 412 return success(); 413 } 414 415 if (auto vectorType = type.dyn_cast<VectorType>()) { 416 uint32_t elementTypeID = 0; 417 if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID, 418 serializationCtx))) { 419 return failure(); 420 } 421 typeEnum = spirv::Opcode::OpTypeVector; 422 operands.push_back(elementTypeID); 423 operands.push_back(vectorType.getNumElements()); 424 return success(); 425 } 426 427 if (auto imageType = type.dyn_cast<spirv::ImageType>()) { 428 typeEnum = spirv::Opcode::OpTypeImage; 429 uint32_t sampledTypeID = 0; 430 if (failed(processType(loc, imageType.getElementType(), sampledTypeID))) 431 return failure(); 432 433 operands.push_back(sampledTypeID); 434 operands.push_back(static_cast<uint32_t>(imageType.getDim())); 435 operands.push_back(static_cast<uint32_t>(imageType.getDepthInfo())); 436 operands.push_back(static_cast<uint32_t>(imageType.getArrayedInfo())); 437 operands.push_back(static_cast<uint32_t>(imageType.getSamplingInfo())); 438 operands.push_back(static_cast<uint32_t>(imageType.getSamplerUseInfo())); 439 operands.push_back(static_cast<uint32_t>(imageType.getImageFormat())); 440 return success(); 441 } 442 443 if (auto arrayType = type.dyn_cast<spirv::ArrayType>()) { 444 typeEnum = spirv::Opcode::OpTypeArray; 445 uint32_t elementTypeID = 0; 446 if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID, 447 serializationCtx))) { 448 return failure(); 449 } 450 operands.push_back(elementTypeID); 451 if (auto elementCountID = prepareConstantInt( 452 loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()))) { 453 operands.push_back(elementCountID); 454 } 455 return processTypeDecoration(loc, arrayType, resultID); 456 } 457 458 if (auto ptrType = type.dyn_cast<spirv::PointerType>()) { 459 uint32_t pointeeTypeID = 0; 460 spirv::StructType pointeeStruct = 461 ptrType.getPointeeType().dyn_cast<spirv::StructType>(); 462 463 if (pointeeStruct && pointeeStruct.isIdentified() && 464 serializationCtx.count(pointeeStruct.getIdentifier()) != 0) { 465 // A recursive reference to an enclosing struct is found. 466 // 467 // 1. Prepare an OpTypeForwardPointer with resultID and the ptr storage 468 // class as operands. 469 SmallVector<uint32_t, 2> forwardPtrOperands; 470 forwardPtrOperands.push_back(resultID); 471 forwardPtrOperands.push_back( 472 static_cast<uint32_t>(ptrType.getStorageClass())); 473 474 (void)encodeInstructionInto(typesGlobalValues, 475 spirv::Opcode::OpTypeForwardPointer, 476 forwardPtrOperands); 477 478 // 2. Find the pointee (enclosing) struct. 479 auto structType = spirv::StructType::getIdentified( 480 module.getContext(), pointeeStruct.getIdentifier()); 481 482 if (!structType) 483 return failure(); 484 485 // 3. Mark the OpTypePointer that is supposed to be emitted by this call 486 // as deferred. 487 deferSerialization = true; 488 489 // 4. Record the info needed to emit the deferred OpTypePointer 490 // instruction when the enclosing struct is completely serialized. 491 recursiveStructInfos[structType].push_back( 492 {resultID, ptrType.getStorageClass()}); 493 } else { 494 if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID, 495 serializationCtx))) 496 return failure(); 497 } 498 499 typeEnum = spirv::Opcode::OpTypePointer; 500 operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass())); 501 operands.push_back(pointeeTypeID); 502 503 if (isInterfaceStructPtrType(ptrType)) { 504 if (failed(emitDecoration(getTypeID(pointeeStruct), 505 spirv::Decoration::Block))) 506 return emitError(loc, "cannot decorate ") 507 << pointeeStruct << " with Block decoration"; 508 } 509 510 return success(); 511 } 512 513 if (auto runtimeArrayType = type.dyn_cast<spirv::RuntimeArrayType>()) { 514 uint32_t elementTypeID = 0; 515 if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(), 516 elementTypeID, serializationCtx))) { 517 return failure(); 518 } 519 typeEnum = spirv::Opcode::OpTypeRuntimeArray; 520 operands.push_back(elementTypeID); 521 return processTypeDecoration(loc, runtimeArrayType, resultID); 522 } 523 524 if (auto sampledImageType = type.dyn_cast<spirv::SampledImageType>()) { 525 typeEnum = spirv::Opcode::OpTypeSampledImage; 526 uint32_t imageTypeID = 0; 527 if (failed( 528 processType(loc, sampledImageType.getImageType(), imageTypeID))) { 529 return failure(); 530 } 531 operands.push_back(imageTypeID); 532 return success(); 533 } 534 535 if (auto structType = type.dyn_cast<spirv::StructType>()) { 536 if (structType.isIdentified()) { 537 (void)processName(resultID, structType.getIdentifier()); 538 serializationCtx.insert(structType.getIdentifier()); 539 } 540 541 bool hasOffset = structType.hasOffset(); 542 for (auto elementIndex : 543 llvm::seq<uint32_t>(0, structType.getNumElements())) { 544 uint32_t elementTypeID = 0; 545 if (failed(processTypeImpl(loc, structType.getElementType(elementIndex), 546 elementTypeID, serializationCtx))) { 547 return failure(); 548 } 549 operands.push_back(elementTypeID); 550 if (hasOffset) { 551 // Decorate each struct member with an offset 552 spirv::StructType::MemberDecorationInfo offsetDecoration{ 553 elementIndex, /*hasValue=*/1, spirv::Decoration::Offset, 554 static_cast<uint32_t>(structType.getMemberOffset(elementIndex))}; 555 if (failed(processMemberDecoration(resultID, offsetDecoration))) { 556 return emitError(loc, "cannot decorate ") 557 << elementIndex << "-th member of " << structType 558 << " with its offset"; 559 } 560 } 561 } 562 SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations; 563 structType.getMemberDecorations(memberDecorations); 564 565 for (auto &memberDecoration : memberDecorations) { 566 if (failed(processMemberDecoration(resultID, memberDecoration))) { 567 return emitError(loc, "cannot decorate ") 568 << static_cast<uint32_t>(memberDecoration.memberIndex) 569 << "-th member of " << structType << " with " 570 << stringifyDecoration(memberDecoration.decoration); 571 } 572 } 573 574 typeEnum = spirv::Opcode::OpTypeStruct; 575 576 if (structType.isIdentified()) 577 serializationCtx.remove(structType.getIdentifier()); 578 579 return success(); 580 } 581 582 if (auto cooperativeMatrixType = 583 type.dyn_cast<spirv::CooperativeMatrixNVType>()) { 584 uint32_t elementTypeID = 0; 585 if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(), 586 elementTypeID, serializationCtx))) { 587 return failure(); 588 } 589 typeEnum = spirv::Opcode::OpTypeCooperativeMatrixNV; 590 auto getConstantOp = [&](uint32_t id) { 591 auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id); 592 return prepareConstantInt(loc, attr); 593 }; 594 operands.push_back(elementTypeID); 595 operands.push_back( 596 getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()))); 597 operands.push_back(getConstantOp(cooperativeMatrixType.getRows())); 598 operands.push_back(getConstantOp(cooperativeMatrixType.getColumns())); 599 return success(); 600 } 601 602 if (auto matrixType = type.dyn_cast<spirv::MatrixType>()) { 603 uint32_t elementTypeID = 0; 604 if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID, 605 serializationCtx))) { 606 return failure(); 607 } 608 typeEnum = spirv::Opcode::OpTypeMatrix; 609 operands.push_back(elementTypeID); 610 operands.push_back(matrixType.getNumColumns()); 611 return success(); 612 } 613 614 // TODO: Handle other types. 615 return emitError(loc, "unhandled type in serialization: ") << type; 616 } 617 618 LogicalResult 619 Serializer::prepareFunctionType(Location loc, FunctionType type, 620 spirv::Opcode &typeEnum, 621 SmallVectorImpl<uint32_t> &operands) { 622 typeEnum = spirv::Opcode::OpTypeFunction; 623 assert(type.getNumResults() <= 1 && 624 "serialization supports only a single return value"); 625 uint32_t resultID = 0; 626 if (failed(processType( 627 loc, type.getNumResults() == 1 ? type.getResult(0) : getVoidType(), 628 resultID))) { 629 return failure(); 630 } 631 operands.push_back(resultID); 632 for (auto &res : type.getInputs()) { 633 uint32_t argTypeID = 0; 634 if (failed(processType(loc, res, argTypeID))) { 635 return failure(); 636 } 637 operands.push_back(argTypeID); 638 } 639 return success(); 640 } 641 642 //===----------------------------------------------------------------------===// 643 // Constant 644 //===----------------------------------------------------------------------===// 645 646 uint32_t Serializer::prepareConstant(Location loc, Type constType, 647 Attribute valueAttr) { 648 if (auto id = prepareConstantScalar(loc, valueAttr)) { 649 return id; 650 } 651 652 // This is a composite literal. We need to handle each component separately 653 // and then emit an OpConstantComposite for the whole. 654 655 if (auto id = getConstantID(valueAttr)) { 656 return id; 657 } 658 659 uint32_t typeID = 0; 660 if (failed(processType(loc, constType, typeID))) { 661 return 0; 662 } 663 664 uint32_t resultID = 0; 665 if (auto attr = valueAttr.dyn_cast<DenseElementsAttr>()) { 666 int rank = attr.getType().dyn_cast<ShapedType>().getRank(); 667 SmallVector<uint64_t, 4> index(rank); 668 resultID = prepareDenseElementsConstant(loc, constType, attr, 669 /*dim=*/0, index); 670 } else if (auto arrayAttr = valueAttr.dyn_cast<ArrayAttr>()) { 671 resultID = prepareArrayConstant(loc, constType, arrayAttr); 672 } 673 674 if (resultID == 0) { 675 emitError(loc, "cannot serialize attribute: ") << valueAttr; 676 return 0; 677 } 678 679 constIDMap[valueAttr] = resultID; 680 return resultID; 681 } 682 683 uint32_t Serializer::prepareArrayConstant(Location loc, Type constType, 684 ArrayAttr attr) { 685 uint32_t typeID = 0; 686 if (failed(processType(loc, constType, typeID))) { 687 return 0; 688 } 689 690 uint32_t resultID = getNextID(); 691 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 692 operands.reserve(attr.size() + 2); 693 auto elementType = constType.cast<spirv::ArrayType>().getElementType(); 694 for (Attribute elementAttr : attr) { 695 if (auto elementID = prepareConstant(loc, elementType, elementAttr)) { 696 operands.push_back(elementID); 697 } else { 698 return 0; 699 } 700 } 701 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 702 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 703 704 return resultID; 705 } 706 707 // TODO: Turn the below function into iterative function, instead of 708 // recursive function. 709 uint32_t 710 Serializer::prepareDenseElementsConstant(Location loc, Type constType, 711 DenseElementsAttr valueAttr, int dim, 712 MutableArrayRef<uint64_t> index) { 713 auto shapedType = valueAttr.getType().dyn_cast<ShapedType>(); 714 assert(dim <= shapedType.getRank()); 715 if (shapedType.getRank() == dim) { 716 if (auto attr = valueAttr.dyn_cast<DenseIntElementsAttr>()) { 717 return attr.getType().getElementType().isInteger(1) 718 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index]) 719 : prepareConstantInt(loc, 720 attr.getValues<IntegerAttr>()[index]); 721 } 722 if (auto attr = valueAttr.dyn_cast<DenseFPElementsAttr>()) { 723 return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]); 724 } 725 return 0; 726 } 727 728 uint32_t typeID = 0; 729 if (failed(processType(loc, constType, typeID))) { 730 return 0; 731 } 732 733 uint32_t resultID = getNextID(); 734 SmallVector<uint32_t, 4> operands = {typeID, resultID}; 735 operands.reserve(shapedType.getDimSize(dim) + 2); 736 auto elementType = constType.cast<spirv::CompositeType>().getElementType(0); 737 for (int i = 0; i < shapedType.getDimSize(dim); ++i) { 738 index[dim] = i; 739 if (auto elementID = prepareDenseElementsConstant( 740 loc, elementType, valueAttr, dim + 1, index)) { 741 operands.push_back(elementID); 742 } else { 743 return 0; 744 } 745 } 746 spirv::Opcode opcode = spirv::Opcode::OpConstantComposite; 747 (void)encodeInstructionInto(typesGlobalValues, opcode, operands); 748 749 return resultID; 750 } 751 752 uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr, 753 bool isSpec) { 754 if (auto floatAttr = valueAttr.dyn_cast<FloatAttr>()) { 755 return prepareConstantFp(loc, floatAttr, isSpec); 756 } 757 if (auto boolAttr = valueAttr.dyn_cast<BoolAttr>()) { 758 return prepareConstantBool(loc, boolAttr, isSpec); 759 } 760 if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) { 761 return prepareConstantInt(loc, intAttr, isSpec); 762 } 763 764 return 0; 765 } 766 767 uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr, 768 bool isSpec) { 769 if (!isSpec) { 770 // We can de-duplicate normal constants, but not specialization constants. 771 if (auto id = getConstantID(boolAttr)) { 772 return id; 773 } 774 } 775 776 // Process the type for this bool literal 777 uint32_t typeID = 0; 778 if (failed(processType(loc, boolAttr.getType(), typeID))) { 779 return 0; 780 } 781 782 auto resultID = getNextID(); 783 auto opcode = boolAttr.getValue() 784 ? (isSpec ? spirv::Opcode::OpSpecConstantTrue 785 : spirv::Opcode::OpConstantTrue) 786 : (isSpec ? spirv::Opcode::OpSpecConstantFalse 787 : spirv::Opcode::OpConstantFalse); 788 (void)encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID}); 789 790 if (!isSpec) { 791 constIDMap[boolAttr] = resultID; 792 } 793 return resultID; 794 } 795 796 uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr, 797 bool isSpec) { 798 if (!isSpec) { 799 // We can de-duplicate normal constants, but not specialization constants. 800 if (auto id = getConstantID(intAttr)) { 801 return id; 802 } 803 } 804 805 // Process the type for this integer literal 806 uint32_t typeID = 0; 807 if (failed(processType(loc, intAttr.getType(), typeID))) { 808 return 0; 809 } 810 811 auto resultID = getNextID(); 812 APInt value = intAttr.getValue(); 813 unsigned bitwidth = value.getBitWidth(); 814 bool isSigned = value.isSignedIntN(bitwidth); 815 816 auto opcode = 817 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 818 819 switch (bitwidth) { 820 // According to SPIR-V spec, "When the type's bit width is less than 821 // 32-bits, the literal's value appears in the low-order bits of the word, 822 // and the high-order bits must be 0 for a floating-point type, or 0 for an 823 // integer type with Signedness of 0, or sign extended when Signedness 824 // is 1." 825 case 32: 826 case 16: 827 case 8: { 828 uint32_t word = 0; 829 if (isSigned) { 830 word = static_cast<int32_t>(value.getSExtValue()); 831 } else { 832 word = static_cast<uint32_t>(value.getZExtValue()); 833 } 834 (void)encodeInstructionInto(typesGlobalValues, opcode, 835 {typeID, resultID, word}); 836 } break; 837 // According to SPIR-V spec: "When the type's bit width is larger than one 838 // word, the literal’s low-order words appear first." 839 case 64: { 840 struct DoubleWord { 841 uint32_t word1; 842 uint32_t word2; 843 } words; 844 if (isSigned) { 845 words = llvm::bit_cast<DoubleWord>(value.getSExtValue()); 846 } else { 847 words = llvm::bit_cast<DoubleWord>(value.getZExtValue()); 848 } 849 (void)encodeInstructionInto(typesGlobalValues, opcode, 850 {typeID, resultID, words.word1, words.word2}); 851 } break; 852 default: { 853 std::string valueStr; 854 llvm::raw_string_ostream rss(valueStr); 855 value.print(rss, /*isSigned=*/false); 856 857 emitError(loc, "cannot serialize ") 858 << bitwidth << "-bit integer literal: " << rss.str(); 859 return 0; 860 } 861 } 862 863 if (!isSpec) { 864 constIDMap[intAttr] = resultID; 865 } 866 return resultID; 867 } 868 869 uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr, 870 bool isSpec) { 871 if (!isSpec) { 872 // We can de-duplicate normal constants, but not specialization constants. 873 if (auto id = getConstantID(floatAttr)) { 874 return id; 875 } 876 } 877 878 // Process the type for this float literal 879 uint32_t typeID = 0; 880 if (failed(processType(loc, floatAttr.getType(), typeID))) { 881 return 0; 882 } 883 884 auto resultID = getNextID(); 885 APFloat value = floatAttr.getValue(); 886 APInt intValue = value.bitcastToAPInt(); 887 888 auto opcode = 889 isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant; 890 891 if (&value.getSemantics() == &APFloat::IEEEsingle()) { 892 uint32_t word = llvm::bit_cast<uint32_t>(value.convertToFloat()); 893 (void)encodeInstructionInto(typesGlobalValues, opcode, 894 {typeID, resultID, word}); 895 } else if (&value.getSemantics() == &APFloat::IEEEdouble()) { 896 struct DoubleWord { 897 uint32_t word1; 898 uint32_t word2; 899 } words = llvm::bit_cast<DoubleWord>(value.convertToDouble()); 900 (void)encodeInstructionInto(typesGlobalValues, opcode, 901 {typeID, resultID, words.word1, words.word2}); 902 } else if (&value.getSemantics() == &APFloat::IEEEhalf()) { 903 uint32_t word = 904 static_cast<uint32_t>(value.bitcastToAPInt().getZExtValue()); 905 (void)encodeInstructionInto(typesGlobalValues, opcode, 906 {typeID, resultID, word}); 907 } else { 908 std::string valueStr; 909 llvm::raw_string_ostream rss(valueStr); 910 value.print(rss); 911 912 emitError(loc, "cannot serialize ") 913 << floatAttr.getType() << "-typed float literal: " << rss.str(); 914 return 0; 915 } 916 917 if (!isSpec) { 918 constIDMap[floatAttr] = resultID; 919 } 920 return resultID; 921 } 922 923 //===----------------------------------------------------------------------===// 924 // Control flow 925 //===----------------------------------------------------------------------===// 926 927 uint32_t Serializer::getOrCreateBlockID(Block *block) { 928 if (uint32_t id = getBlockID(block)) 929 return id; 930 return blockIDMap[block] = getNextID(); 931 } 932 933 LogicalResult 934 Serializer::processBlock(Block *block, bool omitLabel, 935 function_ref<void()> actionBeforeTerminator) { 936 LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n"); 937 LLVM_DEBUG(block->print(llvm::dbgs())); 938 LLVM_DEBUG(llvm::dbgs() << '\n'); 939 if (!omitLabel) { 940 uint32_t blockID = getOrCreateBlockID(block); 941 LLVM_DEBUG(llvm::dbgs() 942 << "[block] " << block << " (id = " << blockID << ")\n"); 943 944 // Emit OpLabel for this block. 945 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, 946 {blockID}); 947 } 948 949 // Emit OpPhi instructions for block arguments, if any. 950 if (failed(emitPhiForBlockArguments(block))) 951 return failure(); 952 953 // Process each op in this block except the terminator. 954 for (auto &op : llvm::make_range(block->begin(), std::prev(block->end()))) { 955 if (failed(processOperation(&op))) 956 return failure(); 957 } 958 959 // Process the terminator. 960 if (actionBeforeTerminator) 961 actionBeforeTerminator(); 962 if (failed(processOperation(&block->back()))) 963 return failure(); 964 965 return success(); 966 } 967 968 LogicalResult Serializer::emitPhiForBlockArguments(Block *block) { 969 // Nothing to do if this block has no arguments or it's the entry block, which 970 // always has the same arguments as the function signature. 971 if (block->args_empty() || block->isEntryBlock()) 972 return success(); 973 974 // If the block has arguments, we need to create SPIR-V OpPhi instructions. 975 // A SPIR-V OpPhi instruction is of the syntax: 976 // OpPhi | result type | result <id> | (value <id>, parent block <id>) pair 977 // So we need to collect all predecessor blocks and the arguments they send 978 // to this block. 979 SmallVector<std::pair<Block *, OperandRange>, 4> predecessors; 980 for (Block *predecessor : block->getPredecessors()) { 981 auto *terminator = predecessor->getTerminator(); 982 // The predecessor here is the immediate one according to MLIR's IR 983 // structure. It does not directly map to the incoming parent block for the 984 // OpPhi instructions at SPIR-V binary level. This is because structured 985 // control flow ops are serialized to multiple SPIR-V blocks. If there is a 986 // spv.mlir.selection/spv.mlir.loop op in the MLIR predecessor block, the 987 // branch op jumping to the OpPhi's block then resides in the previous 988 // structured control flow op's merge block. 989 predecessor = getPhiIncomingBlock(predecessor); 990 if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) { 991 predecessors.emplace_back(predecessor, branchOp.getOperands()); 992 } else if (auto branchCondOp = 993 dyn_cast<spirv::BranchConditionalOp>(terminator)) { 994 Optional<OperandRange> blockOperands; 995 996 for (auto successorIdx : 997 llvm::seq<unsigned>(0, predecessor->getNumSuccessors())) 998 if (predecessor->getSuccessors()[successorIdx] == block) { 999 blockOperands = branchCondOp.getSuccessorOperands(successorIdx); 1000 break; 1001 } 1002 1003 assert(blockOperands && !blockOperands->empty() && 1004 "expected non-empty block operand range"); 1005 predecessors.emplace_back(predecessor, *blockOperands); 1006 } else { 1007 return terminator->emitError("unimplemented terminator for Phi creation"); 1008 } 1009 } 1010 1011 // Then create OpPhi instruction for each of the block argument. 1012 for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) { 1013 BlockArgument arg = block->getArgument(argIndex); 1014 1015 // Get the type <id> and result <id> for this OpPhi instruction. 1016 uint32_t phiTypeID = 0; 1017 if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID))) 1018 return failure(); 1019 uint32_t phiID = getNextID(); 1020 1021 LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' ' 1022 << arg << " (id = " << phiID << ")\n"); 1023 1024 // Prepare the (value <id>, parent block <id>) pairs. 1025 SmallVector<uint32_t, 8> phiArgs; 1026 phiArgs.push_back(phiTypeID); 1027 phiArgs.push_back(phiID); 1028 1029 for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) { 1030 Value value = predecessors[predIndex].second[argIndex]; 1031 uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first); 1032 LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId 1033 << ") value " << value << ' '); 1034 // Each pair is a value <id> ... 1035 uint32_t valueId = getValueID(value); 1036 if (valueId == 0) { 1037 // The op generating this value hasn't been visited yet so we don't have 1038 // an <id> assigned yet. Record this to fix up later. 1039 LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n"); 1040 deferredPhiValues[value].push_back(functionBody.size() + 1 + 1041 phiArgs.size()); 1042 } else { 1043 LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n"); 1044 } 1045 phiArgs.push_back(valueId); 1046 // ... and a parent block <id>. 1047 phiArgs.push_back(predBlockId); 1048 } 1049 1050 (void)encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs); 1051 valueIDMap[arg] = phiID; 1052 } 1053 1054 return success(); 1055 } 1056 1057 //===----------------------------------------------------------------------===// 1058 // Operation 1059 //===----------------------------------------------------------------------===// 1060 1061 LogicalResult Serializer::encodeExtensionInstruction( 1062 Operation *op, StringRef extensionSetName, uint32_t extensionOpcode, 1063 ArrayRef<uint32_t> operands) { 1064 // Check if the extension has been imported. 1065 auto &setID = extendedInstSetIDMap[extensionSetName]; 1066 if (!setID) { 1067 setID = getNextID(); 1068 SmallVector<uint32_t, 16> importOperands; 1069 importOperands.push_back(setID); 1070 if (failed( 1071 spirv::encodeStringLiteralInto(importOperands, extensionSetName)) || 1072 failed(encodeInstructionInto( 1073 extendedSets, spirv::Opcode::OpExtInstImport, importOperands))) { 1074 return failure(); 1075 } 1076 } 1077 1078 // The first two operands are the result type <id> and result <id>. The set 1079 // <id> and the opcode need to be insert after this. 1080 if (operands.size() < 2) { 1081 return op->emitError("extended instructions must have a result encoding"); 1082 } 1083 SmallVector<uint32_t, 8> extInstOperands; 1084 extInstOperands.reserve(operands.size() + 2); 1085 extInstOperands.append(operands.begin(), std::next(operands.begin(), 2)); 1086 extInstOperands.push_back(setID); 1087 extInstOperands.push_back(extensionOpcode); 1088 extInstOperands.append(std::next(operands.begin(), 2), operands.end()); 1089 return encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst, 1090 extInstOperands); 1091 } 1092 1093 LogicalResult Serializer::processOperation(Operation *opInst) { 1094 LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n"); 1095 1096 // First dispatch the ops that do not directly mirror an instruction from 1097 // the SPIR-V spec. 1098 return TypeSwitch<Operation *, LogicalResult>(opInst) 1099 .Case([&](spirv::AddressOfOp op) { return processAddressOfOp(op); }) 1100 .Case([&](spirv::BranchOp op) { return processBranchOp(op); }) 1101 .Case([&](spirv::BranchConditionalOp op) { 1102 return processBranchConditionalOp(op); 1103 }) 1104 .Case([&](spirv::ConstantOp op) { return processConstantOp(op); }) 1105 .Case([&](spirv::FuncOp op) { return processFuncOp(op); }) 1106 .Case([&](spirv::GlobalVariableOp op) { 1107 return processGlobalVariableOp(op); 1108 }) 1109 .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) 1110 .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) 1111 .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) 1112 .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) 1113 .Case([&](spirv::SpecConstantCompositeOp op) { 1114 return processSpecConstantCompositeOp(op); 1115 }) 1116 .Case([&](spirv::SpecConstantOperationOp op) { 1117 return processSpecConstantOperationOp(op); 1118 }) 1119 .Case([&](spirv::UndefOp op) { return processUndefOp(op); }) 1120 .Case([&](spirv::VariableOp op) { return processVariableOp(op); }) 1121 1122 // Then handle all the ops that directly mirror SPIR-V instructions with 1123 // auto-generated methods. 1124 .Default( 1125 [&](Operation *op) { return dispatchToAutogenSerialization(op); }); 1126 } 1127 1128 LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op, 1129 StringRef extInstSet, 1130 uint32_t opcode) { 1131 SmallVector<uint32_t, 4> operands; 1132 Location loc = op->getLoc(); 1133 1134 uint32_t resultID = 0; 1135 if (op->getNumResults() != 0) { 1136 uint32_t resultTypeID = 0; 1137 if (failed(processType(loc, op->getResult(0).getType(), resultTypeID))) 1138 return failure(); 1139 operands.push_back(resultTypeID); 1140 1141 resultID = getNextID(); 1142 operands.push_back(resultID); 1143 valueIDMap[op->getResult(0)] = resultID; 1144 }; 1145 1146 for (Value operand : op->getOperands()) 1147 operands.push_back(getValueID(operand)); 1148 1149 (void)emitDebugLine(functionBody, loc); 1150 1151 if (extInstSet.empty()) { 1152 (void)encodeInstructionInto(functionBody, 1153 static_cast<spirv::Opcode>(opcode), operands); 1154 } else { 1155 (void)encodeExtensionInstruction(op, extInstSet, opcode, operands); 1156 } 1157 1158 if (op->getNumResults() != 0) { 1159 for (auto attr : op->getAttrs()) { 1160 if (failed(processDecoration(loc, resultID, attr))) 1161 return failure(); 1162 } 1163 } 1164 1165 return success(); 1166 } 1167 1168 LogicalResult Serializer::emitDecoration(uint32_t target, 1169 spirv::Decoration decoration, 1170 ArrayRef<uint32_t> params) { 1171 uint32_t wordCount = 3 + params.size(); 1172 decorations.push_back( 1173 spirv::getPrefixedOpcode(wordCount, spirv::Opcode::OpDecorate)); 1174 decorations.push_back(target); 1175 decorations.push_back(static_cast<uint32_t>(decoration)); 1176 decorations.append(params.begin(), params.end()); 1177 return success(); 1178 } 1179 1180 LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary, 1181 Location loc) { 1182 if (!options.emitDebugInfo) 1183 return success(); 1184 1185 if (lastProcessedWasMergeInst) { 1186 lastProcessedWasMergeInst = false; 1187 return success(); 1188 } 1189 1190 auto fileLoc = loc.dyn_cast<FileLineColLoc>(); 1191 if (fileLoc) 1192 (void)encodeInstructionInto( 1193 binary, spirv::Opcode::OpLine, 1194 {fileID, fileLoc.getLine(), fileLoc.getColumn()}); 1195 return success(); 1196 } 1197 } // namespace spirv 1198 } // namespace mlir 1199