1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Deserializer methods for SPIR-V binary instructions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Deserializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/Location.h" 18 #include "mlir/Target/SPIRV/SPIRVBinaryUtils.h" 19 #include "llvm/ADT/STLExtras.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "llvm/Support/Debug.h" 22 23 using namespace mlir; 24 25 #define DEBUG_TYPE "spirv-deserialization" 26 27 //===----------------------------------------------------------------------===// 28 // Utility Functions 29 //===----------------------------------------------------------------------===// 30 31 /// Extracts the opcode from the given first word of a SPIR-V instruction. 32 static inline spirv::Opcode extractOpcode(uint32_t word) { 33 return static_cast<spirv::Opcode>(word & 0xffff); 34 } 35 36 //===----------------------------------------------------------------------===// 37 // Instruction 38 //===----------------------------------------------------------------------===// 39 40 Value spirv::Deserializer::getValue(uint32_t id) { 41 if (auto constInfo = getConstant(id)) { 42 // Materialize a `spv.Constant` op at every use site. 43 return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second, 44 constInfo->first); 45 } 46 if (auto varOp = getGlobalVariable(id)) { 47 auto addressOfOp = opBuilder.create<spirv::AddressOfOp>( 48 unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation())); 49 return addressOfOp.pointer(); 50 } 51 if (auto constOp = getSpecConstant(id)) { 52 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( 53 unknownLoc, constOp.default_value().getType(), 54 SymbolRefAttr::get(constOp.getOperation())); 55 return referenceOfOp.reference(); 56 } 57 if (auto constCompositeOp = getSpecConstantComposite(id)) { 58 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( 59 unknownLoc, constCompositeOp.type(), 60 SymbolRefAttr::get(constCompositeOp.getOperation())); 61 return referenceOfOp.reference(); 62 } 63 if (auto specConstOperationInfo = getSpecConstantOperation(id)) { 64 return materializeSpecConstantOperation( 65 id, specConstOperationInfo->enclodesOpcode, 66 specConstOperationInfo->resultTypeID, 67 specConstOperationInfo->enclosedOpOperands); 68 } 69 if (auto undef = getUndefType(id)) { 70 return opBuilder.create<spirv::UndefOp>(unknownLoc, undef); 71 } 72 return valueMap.lookup(id); 73 } 74 75 LogicalResult 76 spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode, 77 ArrayRef<uint32_t> &operands, 78 Optional<spirv::Opcode> expectedOpcode) { 79 auto binarySize = binary.size(); 80 if (curOffset >= binarySize) { 81 return emitError(unknownLoc, "expected ") 82 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode) 83 : "more") 84 << " instruction"; 85 } 86 87 // For each instruction, get its word count from the first word to slice it 88 // from the stream properly, and then dispatch to the instruction handler. 89 90 uint32_t wordCount = binary[curOffset] >> 16; 91 92 if (wordCount == 0) 93 return emitError(unknownLoc, "word count cannot be zero"); 94 95 uint32_t nextOffset = curOffset + wordCount; 96 if (nextOffset > binarySize) 97 return emitError(unknownLoc, "insufficient words for the last instruction"); 98 99 opcode = extractOpcode(binary[curOffset]); 100 operands = binary.slice(curOffset + 1, wordCount - 1); 101 curOffset = nextOffset; 102 return success(); 103 } 104 105 LogicalResult spirv::Deserializer::processInstruction( 106 spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) { 107 LLVM_DEBUG(logger.startLine() << "[inst] processing instruction " 108 << spirv::stringifyOpcode(opcode) << "\n"); 109 110 // First dispatch all the instructions whose opcode does not correspond to 111 // those that have a direct mirror in the SPIR-V dialect 112 switch (opcode) { 113 case spirv::Opcode::OpCapability: 114 return processCapability(operands); 115 case spirv::Opcode::OpExtension: 116 return processExtension(operands); 117 case spirv::Opcode::OpExtInst: 118 return processExtInst(operands); 119 case spirv::Opcode::OpExtInstImport: 120 return processExtInstImport(operands); 121 case spirv::Opcode::OpMemberName: 122 return processMemberName(operands); 123 case spirv::Opcode::OpMemoryModel: 124 return processMemoryModel(operands); 125 case spirv::Opcode::OpEntryPoint: 126 case spirv::Opcode::OpExecutionMode: 127 if (deferInstructions) { 128 deferredInstructions.emplace_back(opcode, operands); 129 return success(); 130 } 131 break; 132 case spirv::Opcode::OpVariable: 133 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) { 134 return processGlobalVariable(operands); 135 } 136 break; 137 case spirv::Opcode::OpLine: 138 return processDebugLine(operands); 139 case spirv::Opcode::OpNoLine: 140 clearDebugLine(); 141 return success(); 142 case spirv::Opcode::OpName: 143 return processName(operands); 144 case spirv::Opcode::OpString: 145 return processDebugString(operands); 146 case spirv::Opcode::OpModuleProcessed: 147 case spirv::Opcode::OpSource: 148 case spirv::Opcode::OpSourceContinued: 149 case spirv::Opcode::OpSourceExtension: 150 // TODO: This is debug information embedded in the binary which should be 151 // translated into the spv.module. 152 return success(); 153 case spirv::Opcode::OpTypeVoid: 154 case spirv::Opcode::OpTypeBool: 155 case spirv::Opcode::OpTypeInt: 156 case spirv::Opcode::OpTypeFloat: 157 case spirv::Opcode::OpTypeVector: 158 case spirv::Opcode::OpTypeMatrix: 159 case spirv::Opcode::OpTypeArray: 160 case spirv::Opcode::OpTypeFunction: 161 case spirv::Opcode::OpTypeImage: 162 case spirv::Opcode::OpTypeSampledImage: 163 case spirv::Opcode::OpTypeRuntimeArray: 164 case spirv::Opcode::OpTypeStruct: 165 case spirv::Opcode::OpTypePointer: 166 case spirv::Opcode::OpTypeCooperativeMatrixNV: 167 return processType(opcode, operands); 168 case spirv::Opcode::OpTypeForwardPointer: 169 return processTypeForwardPointer(operands); 170 case spirv::Opcode::OpConstant: 171 return processConstant(operands, /*isSpec=*/false); 172 case spirv::Opcode::OpSpecConstant: 173 return processConstant(operands, /*isSpec=*/true); 174 case spirv::Opcode::OpConstantComposite: 175 return processConstantComposite(operands); 176 case spirv::Opcode::OpSpecConstantComposite: 177 return processSpecConstantComposite(operands); 178 case spirv::Opcode::OpSpecConstantOp: 179 return processSpecConstantOperation(operands); 180 case spirv::Opcode::OpConstantTrue: 181 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); 182 case spirv::Opcode::OpSpecConstantTrue: 183 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); 184 case spirv::Opcode::OpConstantFalse: 185 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); 186 case spirv::Opcode::OpSpecConstantFalse: 187 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); 188 case spirv::Opcode::OpConstantNull: 189 return processConstantNull(operands); 190 case spirv::Opcode::OpDecorate: 191 return processDecoration(operands); 192 case spirv::Opcode::OpMemberDecorate: 193 return processMemberDecoration(operands); 194 case spirv::Opcode::OpFunction: 195 return processFunction(operands); 196 case spirv::Opcode::OpLabel: 197 return processLabel(operands); 198 case spirv::Opcode::OpBranch: 199 return processBranch(operands); 200 case spirv::Opcode::OpBranchConditional: 201 return processBranchConditional(operands); 202 case spirv::Opcode::OpSelectionMerge: 203 return processSelectionMerge(operands); 204 case spirv::Opcode::OpLoopMerge: 205 return processLoopMerge(operands); 206 case spirv::Opcode::OpPhi: 207 return processPhi(operands); 208 case spirv::Opcode::OpUndef: 209 return processUndef(operands); 210 default: 211 break; 212 } 213 return dispatchToAutogenDeserialization(opcode, operands); 214 } 215 216 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr( 217 ArrayRef<uint32_t> words, StringRef opName, bool hasResult, 218 unsigned numOperands) { 219 SmallVector<Type, 1> resultTypes; 220 uint32_t valueID = 0; 221 222 size_t wordIndex = 0; 223 if (hasResult) { 224 if (wordIndex >= words.size()) 225 return emitError(unknownLoc, 226 "expected result type <id> while deserializing for ") 227 << opName; 228 229 // Decode the type <id> 230 auto type = getType(words[wordIndex]); 231 if (!type) 232 return emitError(unknownLoc, "unknown type result <id>: ") 233 << words[wordIndex]; 234 resultTypes.push_back(type); 235 ++wordIndex; 236 237 // Decode the result <id> 238 if (wordIndex >= words.size()) 239 return emitError(unknownLoc, 240 "expected result <id> while deserializing for ") 241 << opName; 242 valueID = words[wordIndex]; 243 ++wordIndex; 244 } 245 246 SmallVector<Value, 4> operands; 247 SmallVector<NamedAttribute, 4> attributes; 248 249 // Decode operands 250 size_t operandIndex = 0; 251 for (; operandIndex < numOperands && wordIndex < words.size(); 252 ++operandIndex, ++wordIndex) { 253 auto arg = getValue(words[wordIndex]); 254 if (!arg) 255 return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex]; 256 operands.push_back(arg); 257 } 258 if (operandIndex != numOperands) { 259 return emitError( 260 unknownLoc, 261 "found less operands than expected when deserializing for ") 262 << opName << "; only " << operandIndex << " of " << numOperands 263 << " processed"; 264 } 265 if (wordIndex != words.size()) { 266 return emitError( 267 unknownLoc, 268 "found more operands than expected when deserializing for ") 269 << opName << "; only " << wordIndex << " of " << words.size() 270 << " processed"; 271 } 272 273 // Attach attributes from decorations 274 if (decorations.count(valueID)) { 275 auto attrs = decorations[valueID].getAttrs(); 276 attributes.append(attrs.begin(), attrs.end()); 277 } 278 279 // Create the op and update bookkeeping maps 280 Location loc = createFileLineColLoc(opBuilder); 281 OperationState opState(loc, opName); 282 opState.addOperands(operands); 283 if (hasResult) 284 opState.addTypes(resultTypes); 285 opState.addAttributes(attributes); 286 Operation *op = opBuilder.create(opState); 287 if (hasResult) 288 valueMap[valueID] = op->getResult(0); 289 290 if (op->hasTrait<OpTrait::IsTerminator>()) 291 clearDebugLine(); 292 293 return success(); 294 } 295 296 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) { 297 if (operands.size() != 2) { 298 return emitError(unknownLoc, "OpUndef instruction must have two operands"); 299 } 300 auto type = getType(operands[0]); 301 if (!type) { 302 return emitError(unknownLoc, "unknown type <id> with OpUndef instruction"); 303 } 304 undefMap[operands[1]] = type; 305 return success(); 306 } 307 308 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) { 309 if (operands.size() < 4) { 310 return emitError(unknownLoc, 311 "OpExtInst must have at least 4 operands, result type " 312 "<id>, result <id>, set <id> and instruction opcode"); 313 } 314 if (!extendedInstSets.count(operands[2])) { 315 return emitError(unknownLoc, "undefined set <id> in OpExtInst"); 316 } 317 SmallVector<uint32_t, 4> slicedOperands; 318 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); 319 slicedOperands.append(std::next(operands.begin(), 4), operands.end()); 320 return dispatchToExtensionSetAutogenDeserialization( 321 extendedInstSets[operands[2]], operands[3], slicedOperands); 322 } 323 324 namespace mlir { 325 namespace spirv { 326 327 template <> 328 LogicalResult 329 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) { 330 unsigned wordIndex = 0; 331 if (wordIndex >= words.size()) { 332 return emitError(unknownLoc, 333 "missing Execution Model specification in OpEntryPoint"); 334 } 335 auto execModel = spirv::ExecutionModelAttr::get( 336 context, static_cast<spirv::ExecutionModel>(words[wordIndex++])); 337 if (wordIndex >= words.size()) { 338 return emitError(unknownLoc, "missing <id> in OpEntryPoint"); 339 } 340 // Get the function <id> 341 auto fnID = words[wordIndex++]; 342 // Get the function name 343 auto fnName = decodeStringLiteral(words, wordIndex); 344 // Verify that the function <id> matches the fnName 345 auto parsedFunc = getFunction(fnID); 346 if (!parsedFunc) { 347 return emitError(unknownLoc, "no function matching <id> ") << fnID; 348 } 349 if (parsedFunc.getName() != fnName) { 350 // The deserializer uses "spirv_fn_<id>" as the function name if the input 351 // SPIR-V blob does not contain a name for it. We should use a more clear 352 // indication for such case rather than relying on naming details. 353 if (!parsedFunc.getName().startswith("spirv_fn_")) 354 return emitError(unknownLoc, 355 "function name mismatch between OpEntryPoint " 356 "and OpFunction with <id> ") 357 << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); 358 parsedFunc.setName(fnName); 359 } 360 SmallVector<Attribute, 4> interface; 361 while (wordIndex < words.size()) { 362 auto arg = getGlobalVariable(words[wordIndex]); 363 if (!arg) { 364 return emitError(unknownLoc, "undefined result <id> ") 365 << words[wordIndex] << " while decoding OpEntryPoint"; 366 } 367 interface.push_back(SymbolRefAttr::get(arg.getOperation())); 368 wordIndex++; 369 } 370 opBuilder.create<spirv::EntryPointOp>( 371 unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName), 372 opBuilder.getArrayAttr(interface)); 373 return success(); 374 } 375 376 template <> 377 LogicalResult 378 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) { 379 unsigned wordIndex = 0; 380 if (wordIndex >= words.size()) { 381 return emitError(unknownLoc, 382 "missing function result <id> in OpExecutionMode"); 383 } 384 // Get the function <id> to get the name of the function 385 auto fnID = words[wordIndex++]; 386 auto fn = getFunction(fnID); 387 if (!fn) { 388 return emitError(unknownLoc, "no function matching <id> ") << fnID; 389 } 390 // Get the Execution mode 391 if (wordIndex >= words.size()) { 392 return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); 393 } 394 auto execMode = spirv::ExecutionModeAttr::get( 395 context, static_cast<spirv::ExecutionMode>(words[wordIndex++])); 396 397 // Get the values 398 SmallVector<Attribute, 4> attrListElems; 399 while (wordIndex < words.size()) { 400 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); 401 } 402 auto values = opBuilder.getArrayAttr(attrListElems); 403 opBuilder.create<spirv::ExecutionModeOp>( 404 unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), 405 execMode, values); 406 return success(); 407 } 408 409 template <> 410 LogicalResult 411 Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) { 412 if (operands.size() != 3) { 413 return emitError( 414 unknownLoc, 415 "OpControlBarrier must have execution scope <id>, memory scope <id> " 416 "and memory semantics <id>"); 417 } 418 419 SmallVector<IntegerAttr, 3> argAttrs; 420 for (auto operand : operands) { 421 auto argAttr = getConstantInt(operand); 422 if (!argAttr) { 423 return emitError(unknownLoc, 424 "expected 32-bit integer constant from <id> ") 425 << operand << " for OpControlBarrier"; 426 } 427 argAttrs.push_back(argAttr); 428 } 429 430 opBuilder.create<spirv::ControlBarrierOp>( 431 unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(), 432 argAttrs[1].cast<spirv::ScopeAttr>(), 433 argAttrs[2].cast<spirv::MemorySemanticsAttr>()); 434 435 return success(); 436 } 437 438 template <> 439 LogicalResult 440 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) { 441 if (operands.size() < 3) { 442 return emitError(unknownLoc, 443 "OpFunctionCall must have at least 3 operands"); 444 } 445 446 Type resultType = getType(operands[0]); 447 if (!resultType) { 448 return emitError(unknownLoc, "undefined result type from <id> ") 449 << operands[0]; 450 } 451 452 // Use null type to mean no result type. 453 if (isVoidType(resultType)) 454 resultType = nullptr; 455 456 auto resultID = operands[1]; 457 auto functionID = operands[2]; 458 459 auto functionName = getFunctionSymbol(functionID); 460 461 SmallVector<Value, 4> arguments; 462 for (auto operand : llvm::drop_begin(operands, 3)) { 463 auto value = getValue(operand); 464 if (!value) { 465 return emitError(unknownLoc, "unknown <id> ") 466 << operand << " used by OpFunctionCall"; 467 } 468 arguments.push_back(value); 469 } 470 471 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>( 472 unknownLoc, resultType, 473 SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments); 474 475 if (resultType) 476 valueMap[resultID] = opFunctionCall.getResult(0); 477 return success(); 478 } 479 480 template <> 481 LogicalResult 482 Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) { 483 if (operands.size() != 2) { 484 return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> " 485 "and memory semantics <id>"); 486 } 487 488 SmallVector<IntegerAttr, 2> argAttrs; 489 for (auto operand : operands) { 490 auto argAttr = getConstantInt(operand); 491 if (!argAttr) { 492 return emitError(unknownLoc, 493 "expected 32-bit integer constant from <id> ") 494 << operand << " for OpMemoryBarrier"; 495 } 496 argAttrs.push_back(argAttr); 497 } 498 499 opBuilder.create<spirv::MemoryBarrierOp>( 500 unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(), 501 argAttrs[1].cast<spirv::MemorySemanticsAttr>()); 502 return success(); 503 } 504 505 template <> 506 LogicalResult 507 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) { 508 SmallVector<Type, 1> resultTypes; 509 size_t wordIndex = 0; 510 SmallVector<Value, 4> operands; 511 SmallVector<NamedAttribute, 4> attributes; 512 513 if (wordIndex < words.size()) { 514 auto arg = getValue(words[wordIndex]); 515 516 if (!arg) { 517 return emitError(unknownLoc, "unknown result <id> : ") 518 << words[wordIndex]; 519 } 520 521 operands.push_back(arg); 522 wordIndex++; 523 } 524 525 if (wordIndex < words.size()) { 526 auto arg = getValue(words[wordIndex]); 527 528 if (!arg) { 529 return emitError(unknownLoc, "unknown result <id> : ") 530 << words[wordIndex]; 531 } 532 533 operands.push_back(arg); 534 wordIndex++; 535 } 536 537 bool isAlignedAttr = false; 538 539 if (wordIndex < words.size()) { 540 auto attrValue = words[wordIndex++]; 541 attributes.push_back(opBuilder.getNamedAttr( 542 "memory_access", opBuilder.getI32IntegerAttr(attrValue))); 543 isAlignedAttr = (attrValue == 2); 544 } 545 546 if (isAlignedAttr && wordIndex < words.size()) { 547 attributes.push_back(opBuilder.getNamedAttr( 548 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); 549 } 550 551 if (wordIndex < words.size()) { 552 attributes.push_back(opBuilder.getNamedAttr( 553 "source_memory_access", 554 opBuilder.getI32IntegerAttr(words[wordIndex++]))); 555 } 556 557 if (wordIndex < words.size()) { 558 attributes.push_back(opBuilder.getNamedAttr( 559 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); 560 } 561 562 if (wordIndex != words.size()) { 563 return emitError(unknownLoc, 564 "found more operands than expected when deserializing " 565 "spirv::CopyMemoryOp, only ") 566 << wordIndex << " of " << words.size() << " processed"; 567 } 568 569 Location loc = createFileLineColLoc(opBuilder); 570 opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes); 571 572 return success(); 573 } 574 575 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and 576 // various Deserializer::processOp<...>() specializations. 577 #define GET_DESERIALIZATION_FNS 578 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" 579 580 } // namespace spirv 581 } // namespace mlir 582