1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Deserializer methods for SPIR-V binary instructions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Deserializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/Location.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/ADT/SmallVector.h" 20 #include "llvm/Support/Debug.h" 21 22 using namespace mlir; 23 24 #define DEBUG_TYPE "spirv-deserialization" 25 26 //===----------------------------------------------------------------------===// 27 // Utility Functions 28 //===----------------------------------------------------------------------===// 29 30 /// Extracts the opcode from the given first word of a SPIR-V instruction. 31 static inline spirv::Opcode extractOpcode(uint32_t word) { 32 return static_cast<spirv::Opcode>(word & 0xffff); 33 } 34 35 //===----------------------------------------------------------------------===// 36 // Instruction 37 //===----------------------------------------------------------------------===// 38 39 Value spirv::Deserializer::getValue(uint32_t id) { 40 if (auto constInfo = getConstant(id)) { 41 // Materialize a `spv.Constant` op at every use site. 42 return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second, 43 constInfo->first); 44 } 45 if (auto varOp = getGlobalVariable(id)) { 46 auto addressOfOp = opBuilder.create<spirv::AddressOfOp>( 47 unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation())); 48 return addressOfOp.pointer(); 49 } 50 if (auto constOp = getSpecConstant(id)) { 51 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( 52 unknownLoc, constOp.default_value().getType(), 53 SymbolRefAttr::get(constOp.getOperation())); 54 return referenceOfOp.reference(); 55 } 56 if (auto constCompositeOp = getSpecConstantComposite(id)) { 57 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( 58 unknownLoc, constCompositeOp.type(), 59 SymbolRefAttr::get(constCompositeOp.getOperation())); 60 return referenceOfOp.reference(); 61 } 62 if (auto specConstOperationInfo = getSpecConstantOperation(id)) { 63 return materializeSpecConstantOperation( 64 id, specConstOperationInfo->enclodesOpcode, 65 specConstOperationInfo->resultTypeID, 66 specConstOperationInfo->enclosedOpOperands); 67 } 68 if (auto undef = getUndefType(id)) { 69 return opBuilder.create<spirv::UndefOp>(unknownLoc, undef); 70 } 71 return valueMap.lookup(id); 72 } 73 74 LogicalResult 75 spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode, 76 ArrayRef<uint32_t> &operands, 77 Optional<spirv::Opcode> expectedOpcode) { 78 auto binarySize = binary.size(); 79 if (curOffset >= binarySize) { 80 return emitError(unknownLoc, "expected ") 81 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode) 82 : "more") 83 << " instruction"; 84 } 85 86 // For each instruction, get its word count from the first word to slice it 87 // from the stream properly, and then dispatch to the instruction handler. 88 89 uint32_t wordCount = binary[curOffset] >> 16; 90 91 if (wordCount == 0) 92 return emitError(unknownLoc, "word count cannot be zero"); 93 94 uint32_t nextOffset = curOffset + wordCount; 95 if (nextOffset > binarySize) 96 return emitError(unknownLoc, "insufficient words for the last instruction"); 97 98 opcode = extractOpcode(binary[curOffset]); 99 operands = binary.slice(curOffset + 1, wordCount - 1); 100 curOffset = nextOffset; 101 return success(); 102 } 103 104 LogicalResult spirv::Deserializer::processInstruction( 105 spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) { 106 LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction " 107 << spirv::stringifyOpcode(opcode) << "\n"); 108 109 // First dispatch all the instructions whose opcode does not correspond to 110 // those that have a direct mirror in the SPIR-V dialect 111 switch (opcode) { 112 case spirv::Opcode::OpCapability: 113 return processCapability(operands); 114 case spirv::Opcode::OpExtension: 115 return processExtension(operands); 116 case spirv::Opcode::OpExtInst: 117 return processExtInst(operands); 118 case spirv::Opcode::OpExtInstImport: 119 return processExtInstImport(operands); 120 case spirv::Opcode::OpMemberName: 121 return processMemberName(operands); 122 case spirv::Opcode::OpMemoryModel: 123 return processMemoryModel(operands); 124 case spirv::Opcode::OpEntryPoint: 125 case spirv::Opcode::OpExecutionMode: 126 if (deferInstructions) { 127 deferredInstructions.emplace_back(opcode, operands); 128 return success(); 129 } 130 break; 131 case spirv::Opcode::OpVariable: 132 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) { 133 return processGlobalVariable(operands); 134 } 135 break; 136 case spirv::Opcode::OpLine: 137 return processDebugLine(operands); 138 case spirv::Opcode::OpNoLine: 139 return clearDebugLine(); 140 case spirv::Opcode::OpName: 141 return processName(operands); 142 case spirv::Opcode::OpString: 143 return processDebugString(operands); 144 case spirv::Opcode::OpModuleProcessed: 145 case spirv::Opcode::OpSource: 146 case spirv::Opcode::OpSourceContinued: 147 case spirv::Opcode::OpSourceExtension: 148 // TODO: This is debug information embedded in the binary which should be 149 // translated into the spv.module. 150 return success(); 151 case spirv::Opcode::OpTypeVoid: 152 case spirv::Opcode::OpTypeBool: 153 case spirv::Opcode::OpTypeInt: 154 case spirv::Opcode::OpTypeFloat: 155 case spirv::Opcode::OpTypeVector: 156 case spirv::Opcode::OpTypeMatrix: 157 case spirv::Opcode::OpTypeArray: 158 case spirv::Opcode::OpTypeFunction: 159 case spirv::Opcode::OpTypeImage: 160 case spirv::Opcode::OpTypeSampledImage: 161 case spirv::Opcode::OpTypeRuntimeArray: 162 case spirv::Opcode::OpTypeStruct: 163 case spirv::Opcode::OpTypePointer: 164 case spirv::Opcode::OpTypeCooperativeMatrixNV: 165 return processType(opcode, operands); 166 case spirv::Opcode::OpTypeForwardPointer: 167 return processTypeForwardPointer(operands); 168 case spirv::Opcode::OpConstant: 169 return processConstant(operands, /*isSpec=*/false); 170 case spirv::Opcode::OpSpecConstant: 171 return processConstant(operands, /*isSpec=*/true); 172 case spirv::Opcode::OpConstantComposite: 173 return processConstantComposite(operands); 174 case spirv::Opcode::OpSpecConstantComposite: 175 return processSpecConstantComposite(operands); 176 case spirv::Opcode::OpSpecConstantOp: 177 return processSpecConstantOperation(operands); 178 case spirv::Opcode::OpConstantTrue: 179 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); 180 case spirv::Opcode::OpSpecConstantTrue: 181 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); 182 case spirv::Opcode::OpConstantFalse: 183 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); 184 case spirv::Opcode::OpSpecConstantFalse: 185 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); 186 case spirv::Opcode::OpConstantNull: 187 return processConstantNull(operands); 188 case spirv::Opcode::OpDecorate: 189 return processDecoration(operands); 190 case spirv::Opcode::OpMemberDecorate: 191 return processMemberDecoration(operands); 192 case spirv::Opcode::OpFunction: 193 return processFunction(operands); 194 case spirv::Opcode::OpLabel: 195 return processLabel(operands); 196 case spirv::Opcode::OpBranch: 197 return processBranch(operands); 198 case spirv::Opcode::OpBranchConditional: 199 return processBranchConditional(operands); 200 case spirv::Opcode::OpSelectionMerge: 201 return processSelectionMerge(operands); 202 case spirv::Opcode::OpLoopMerge: 203 return processLoopMerge(operands); 204 case spirv::Opcode::OpPhi: 205 return processPhi(operands); 206 case spirv::Opcode::OpUndef: 207 return processUndef(operands); 208 default: 209 break; 210 } 211 return dispatchToAutogenDeserialization(opcode, operands); 212 } 213 214 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr( 215 ArrayRef<uint32_t> words, StringRef opName, bool hasResult, 216 unsigned numOperands) { 217 SmallVector<Type, 1> resultTypes; 218 uint32_t valueID = 0; 219 220 size_t wordIndex = 0; 221 if (hasResult) { 222 if (wordIndex >= words.size()) 223 return emitError(unknownLoc, 224 "expected result type <id> while deserializing for ") 225 << opName; 226 227 // Decode the type <id> 228 auto type = getType(words[wordIndex]); 229 if (!type) 230 return emitError(unknownLoc, "unknown type result <id>: ") 231 << words[wordIndex]; 232 resultTypes.push_back(type); 233 ++wordIndex; 234 235 // Decode the result <id> 236 if (wordIndex >= words.size()) 237 return emitError(unknownLoc, 238 "expected result <id> while deserializing for ") 239 << opName; 240 valueID = words[wordIndex]; 241 ++wordIndex; 242 } 243 244 SmallVector<Value, 4> operands; 245 SmallVector<NamedAttribute, 4> attributes; 246 247 // Decode operands 248 size_t operandIndex = 0; 249 for (; operandIndex < numOperands && wordIndex < words.size(); 250 ++operandIndex, ++wordIndex) { 251 auto arg = getValue(words[wordIndex]); 252 if (!arg) 253 return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex]; 254 operands.push_back(arg); 255 } 256 if (operandIndex != numOperands) { 257 return emitError( 258 unknownLoc, 259 "found less operands than expected when deserializing for ") 260 << opName << "; only " << operandIndex << " of " << numOperands 261 << " processed"; 262 } 263 if (wordIndex != words.size()) { 264 return emitError( 265 unknownLoc, 266 "found more operands than expected when deserializing for ") 267 << opName << "; only " << wordIndex << " of " << words.size() 268 << " processed"; 269 } 270 271 // Attach attributes from decorations 272 if (decorations.count(valueID)) { 273 auto attrs = decorations[valueID].getAttrs(); 274 attributes.append(attrs.begin(), attrs.end()); 275 } 276 277 // Create the op and update bookkeeping maps 278 Location loc = createFileLineColLoc(opBuilder); 279 OperationState opState(loc, opName); 280 opState.addOperands(operands); 281 if (hasResult) 282 opState.addTypes(resultTypes); 283 opState.addAttributes(attributes); 284 Operation *op = opBuilder.createOperation(opState); 285 if (hasResult) 286 valueMap[valueID] = op->getResult(0); 287 288 if (op->hasTrait<OpTrait::IsTerminator>()) 289 (void)clearDebugLine(); 290 291 return success(); 292 } 293 294 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) { 295 if (operands.size() != 2) { 296 return emitError(unknownLoc, "OpUndef instruction must have two operands"); 297 } 298 auto type = getType(operands[0]); 299 if (!type) { 300 return emitError(unknownLoc, "unknown type <id> with OpUndef instruction"); 301 } 302 undefMap[operands[1]] = type; 303 return success(); 304 } 305 306 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) { 307 if (operands.size() < 4) { 308 return emitError(unknownLoc, 309 "OpExtInst must have at least 4 operands, result type " 310 "<id>, result <id>, set <id> and instruction opcode"); 311 } 312 if (!extendedInstSets.count(operands[2])) { 313 return emitError(unknownLoc, "undefined set <id> in OpExtInst"); 314 } 315 SmallVector<uint32_t, 4> slicedOperands; 316 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); 317 slicedOperands.append(std::next(operands.begin(), 4), operands.end()); 318 return dispatchToExtensionSetAutogenDeserialization( 319 extendedInstSets[operands[2]], operands[3], slicedOperands); 320 } 321 322 namespace mlir { 323 namespace spirv { 324 325 template <> 326 LogicalResult 327 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) { 328 unsigned wordIndex = 0; 329 if (wordIndex >= words.size()) { 330 return emitError(unknownLoc, 331 "missing Execution Model specification in OpEntryPoint"); 332 } 333 auto execModel = spirv::ExecutionModelAttr::get( 334 context, static_cast<spirv::ExecutionModel>(words[wordIndex++])); 335 if (wordIndex >= words.size()) { 336 return emitError(unknownLoc, "missing <id> in OpEntryPoint"); 337 } 338 // Get the function <id> 339 auto fnID = words[wordIndex++]; 340 // Get the function name 341 auto fnName = decodeStringLiteral(words, wordIndex); 342 // Verify that the function <id> matches the fnName 343 auto parsedFunc = getFunction(fnID); 344 if (!parsedFunc) { 345 return emitError(unknownLoc, "no function matching <id> ") << fnID; 346 } 347 if (parsedFunc.getName() != fnName) { 348 return emitError(unknownLoc, "function name mismatch between OpEntryPoint " 349 "and OpFunction with <id> ") 350 << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); 351 } 352 SmallVector<Attribute, 4> interface; 353 while (wordIndex < words.size()) { 354 auto arg = getGlobalVariable(words[wordIndex]); 355 if (!arg) { 356 return emitError(unknownLoc, "undefined result <id> ") 357 << words[wordIndex] << " while decoding OpEntryPoint"; 358 } 359 interface.push_back(SymbolRefAttr::get(arg.getOperation())); 360 wordIndex++; 361 } 362 opBuilder.create<spirv::EntryPointOp>( 363 unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName), 364 opBuilder.getArrayAttr(interface)); 365 return success(); 366 } 367 368 template <> 369 LogicalResult 370 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) { 371 unsigned wordIndex = 0; 372 if (wordIndex >= words.size()) { 373 return emitError(unknownLoc, 374 "missing function result <id> in OpExecutionMode"); 375 } 376 // Get the function <id> to get the name of the function 377 auto fnID = words[wordIndex++]; 378 auto fn = getFunction(fnID); 379 if (!fn) { 380 return emitError(unknownLoc, "no function matching <id> ") << fnID; 381 } 382 // Get the Execution mode 383 if (wordIndex >= words.size()) { 384 return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); 385 } 386 auto execMode = spirv::ExecutionModeAttr::get( 387 context, static_cast<spirv::ExecutionMode>(words[wordIndex++])); 388 389 // Get the values 390 SmallVector<Attribute, 4> attrListElems; 391 while (wordIndex < words.size()) { 392 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); 393 } 394 auto values = opBuilder.getArrayAttr(attrListElems); 395 opBuilder.create<spirv::ExecutionModeOp>( 396 unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()), 397 execMode, values); 398 return success(); 399 } 400 401 template <> 402 LogicalResult 403 Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) { 404 if (operands.size() != 3) { 405 return emitError( 406 unknownLoc, 407 "OpControlBarrier must have execution scope <id>, memory scope <id> " 408 "and memory semantics <id>"); 409 } 410 411 SmallVector<IntegerAttr, 3> argAttrs; 412 for (auto operand : operands) { 413 auto argAttr = getConstantInt(operand); 414 if (!argAttr) { 415 return emitError(unknownLoc, 416 "expected 32-bit integer constant from <id> ") 417 << operand << " for OpControlBarrier"; 418 } 419 argAttrs.push_back(argAttr); 420 } 421 422 opBuilder.create<spirv::ControlBarrierOp>( 423 unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(), 424 argAttrs[1].cast<spirv::ScopeAttr>(), 425 argAttrs[2].cast<spirv::MemorySemanticsAttr>()); 426 427 return success(); 428 } 429 430 template <> 431 LogicalResult 432 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) { 433 if (operands.size() < 3) { 434 return emitError(unknownLoc, 435 "OpFunctionCall must have at least 3 operands"); 436 } 437 438 Type resultType = getType(operands[0]); 439 if (!resultType) { 440 return emitError(unknownLoc, "undefined result type from <id> ") 441 << operands[0]; 442 } 443 444 // Use null type to mean no result type. 445 if (isVoidType(resultType)) 446 resultType = nullptr; 447 448 auto resultID = operands[1]; 449 auto functionID = operands[2]; 450 451 auto functionName = getFunctionSymbol(functionID); 452 453 SmallVector<Value, 4> arguments; 454 for (auto operand : llvm::drop_begin(operands, 3)) { 455 auto value = getValue(operand); 456 if (!value) { 457 return emitError(unknownLoc, "unknown <id> ") 458 << operand << " used by OpFunctionCall"; 459 } 460 arguments.push_back(value); 461 } 462 463 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>( 464 unknownLoc, resultType, 465 SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments); 466 467 if (resultType) 468 valueMap[resultID] = opFunctionCall.getResult(0); 469 return success(); 470 } 471 472 template <> 473 LogicalResult 474 Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) { 475 if (operands.size() != 2) { 476 return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> " 477 "and memory semantics <id>"); 478 } 479 480 SmallVector<IntegerAttr, 2> argAttrs; 481 for (auto operand : operands) { 482 auto argAttr = getConstantInt(operand); 483 if (!argAttr) { 484 return emitError(unknownLoc, 485 "expected 32-bit integer constant from <id> ") 486 << operand << " for OpMemoryBarrier"; 487 } 488 argAttrs.push_back(argAttr); 489 } 490 491 opBuilder.create<spirv::MemoryBarrierOp>( 492 unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(), 493 argAttrs[1].cast<spirv::MemorySemanticsAttr>()); 494 return success(); 495 } 496 497 template <> 498 LogicalResult 499 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) { 500 SmallVector<Type, 1> resultTypes; 501 size_t wordIndex = 0; 502 SmallVector<Value, 4> operands; 503 SmallVector<NamedAttribute, 4> attributes; 504 505 if (wordIndex < words.size()) { 506 auto arg = getValue(words[wordIndex]); 507 508 if (!arg) { 509 return emitError(unknownLoc, "unknown result <id> : ") 510 << words[wordIndex]; 511 } 512 513 operands.push_back(arg); 514 wordIndex++; 515 } 516 517 if (wordIndex < words.size()) { 518 auto arg = getValue(words[wordIndex]); 519 520 if (!arg) { 521 return emitError(unknownLoc, "unknown result <id> : ") 522 << words[wordIndex]; 523 } 524 525 operands.push_back(arg); 526 wordIndex++; 527 } 528 529 bool isAlignedAttr = false; 530 531 if (wordIndex < words.size()) { 532 auto attrValue = words[wordIndex++]; 533 attributes.push_back(opBuilder.getNamedAttr( 534 "memory_access", opBuilder.getI32IntegerAttr(attrValue))); 535 isAlignedAttr = (attrValue == 2); 536 } 537 538 if (isAlignedAttr && wordIndex < words.size()) { 539 attributes.push_back(opBuilder.getNamedAttr( 540 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); 541 } 542 543 if (wordIndex < words.size()) { 544 attributes.push_back(opBuilder.getNamedAttr( 545 "source_memory_access", 546 opBuilder.getI32IntegerAttr(words[wordIndex++]))); 547 } 548 549 if (wordIndex < words.size()) { 550 attributes.push_back(opBuilder.getNamedAttr( 551 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); 552 } 553 554 if (wordIndex != words.size()) { 555 return emitError(unknownLoc, 556 "found more operands than expected when deserializing " 557 "spirv::CopyMemoryOp, only ") 558 << wordIndex << " of " << words.size() << " processed"; 559 } 560 561 Location loc = createFileLineColLoc(opBuilder); 562 opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes); 563 564 return success(); 565 } 566 567 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and 568 // various Deserializer::processOp<...>() specializations. 569 #define GET_DESERIALIZATION_FNS 570 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" 571 572 } // namespace spirv 573 } // namespace mlir 574