1 //===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the Deserializer methods for SPIR-V binary instructions. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "Deserializer.h" 14 15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/Location.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "llvm/ADT/SmallVector.h" 20 #include "llvm/Support/Debug.h" 21 22 using namespace mlir; 23 24 #define DEBUG_TYPE "spirv-deserialization" 25 26 //===----------------------------------------------------------------------===// 27 // Utility Functions 28 //===----------------------------------------------------------------------===// 29 30 /// Extracts the opcode from the given first word of a SPIR-V instruction. 31 static inline spirv::Opcode extractOpcode(uint32_t word) { 32 return static_cast<spirv::Opcode>(word & 0xffff); 33 } 34 35 //===----------------------------------------------------------------------===// 36 // Instruction 37 //===----------------------------------------------------------------------===// 38 39 Value spirv::Deserializer::getValue(uint32_t id) { 40 if (auto constInfo = getConstant(id)) { 41 // Materialize a `spv.constant` op at every use site. 42 return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second, 43 constInfo->first); 44 } 45 if (auto varOp = getGlobalVariable(id)) { 46 auto addressOfOp = opBuilder.create<spirv::AddressOfOp>( 47 unknownLoc, varOp.type(), 48 opBuilder.getSymbolRefAttr(varOp.getOperation())); 49 return addressOfOp.pointer(); 50 } 51 if (auto constOp = getSpecConstant(id)) { 52 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( 53 unknownLoc, constOp.default_value().getType(), 54 opBuilder.getSymbolRefAttr(constOp.getOperation())); 55 return referenceOfOp.reference(); 56 } 57 if (auto constCompositeOp = getSpecConstantComposite(id)) { 58 auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>( 59 unknownLoc, constCompositeOp.type(), 60 opBuilder.getSymbolRefAttr(constCompositeOp.getOperation())); 61 return referenceOfOp.reference(); 62 } 63 if (auto specConstOperationInfo = getSpecConstantOperation(id)) { 64 return materializeSpecConstantOperation( 65 id, specConstOperationInfo->enclodesOpcode, 66 specConstOperationInfo->resultTypeID, 67 specConstOperationInfo->enclosedOpOperands); 68 } 69 if (auto undef = getUndefType(id)) { 70 return opBuilder.create<spirv::UndefOp>(unknownLoc, undef); 71 } 72 return valueMap.lookup(id); 73 } 74 75 LogicalResult 76 spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode, 77 ArrayRef<uint32_t> &operands, 78 Optional<spirv::Opcode> expectedOpcode) { 79 auto binarySize = binary.size(); 80 if (curOffset >= binarySize) { 81 return emitError(unknownLoc, "expected ") 82 << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode) 83 : "more") 84 << " instruction"; 85 } 86 87 // For each instruction, get its word count from the first word to slice it 88 // from the stream properly, and then dispatch to the instruction handler. 89 90 uint32_t wordCount = binary[curOffset] >> 16; 91 92 if (wordCount == 0) 93 return emitError(unknownLoc, "word count cannot be zero"); 94 95 uint32_t nextOffset = curOffset + wordCount; 96 if (nextOffset > binarySize) 97 return emitError(unknownLoc, "insufficient words for the last instruction"); 98 99 opcode = extractOpcode(binary[curOffset]); 100 operands = binary.slice(curOffset + 1, wordCount - 1); 101 curOffset = nextOffset; 102 return success(); 103 } 104 105 LogicalResult spirv::Deserializer::processInstruction( 106 spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) { 107 LLVM_DEBUG(llvm::dbgs() << "[inst] processing instruction " 108 << spirv::stringifyOpcode(opcode) << "\n"); 109 110 // First dispatch all the instructions whose opcode does not correspond to 111 // those that have a direct mirror in the SPIR-V dialect 112 switch (opcode) { 113 case spirv::Opcode::OpCapability: 114 return processCapability(operands); 115 case spirv::Opcode::OpExtension: 116 return processExtension(operands); 117 case spirv::Opcode::OpExtInst: 118 return processExtInst(operands); 119 case spirv::Opcode::OpExtInstImport: 120 return processExtInstImport(operands); 121 case spirv::Opcode::OpMemberName: 122 return processMemberName(operands); 123 case spirv::Opcode::OpMemoryModel: 124 return processMemoryModel(operands); 125 case spirv::Opcode::OpEntryPoint: 126 case spirv::Opcode::OpExecutionMode: 127 if (deferInstructions) { 128 deferredInstructions.emplace_back(opcode, operands); 129 return success(); 130 } 131 break; 132 case spirv::Opcode::OpVariable: 133 if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) { 134 return processGlobalVariable(operands); 135 } 136 break; 137 case spirv::Opcode::OpLine: 138 return processDebugLine(operands); 139 case spirv::Opcode::OpNoLine: 140 return clearDebugLine(); 141 case spirv::Opcode::OpName: 142 return processName(operands); 143 case spirv::Opcode::OpString: 144 return processDebugString(operands); 145 case spirv::Opcode::OpModuleProcessed: 146 case spirv::Opcode::OpSource: 147 case spirv::Opcode::OpSourceContinued: 148 case spirv::Opcode::OpSourceExtension: 149 // TODO: This is debug information embedded in the binary which should be 150 // translated into the spv.module. 151 return success(); 152 case spirv::Opcode::OpTypeVoid: 153 case spirv::Opcode::OpTypeBool: 154 case spirv::Opcode::OpTypeInt: 155 case spirv::Opcode::OpTypeFloat: 156 case spirv::Opcode::OpTypeVector: 157 case spirv::Opcode::OpTypeMatrix: 158 case spirv::Opcode::OpTypeArray: 159 case spirv::Opcode::OpTypeFunction: 160 case spirv::Opcode::OpTypeRuntimeArray: 161 case spirv::Opcode::OpTypeStruct: 162 case spirv::Opcode::OpTypePointer: 163 case spirv::Opcode::OpTypeCooperativeMatrixNV: 164 return processType(opcode, operands); 165 case spirv::Opcode::OpTypeForwardPointer: 166 return processTypeForwardPointer(operands); 167 case spirv::Opcode::OpConstant: 168 return processConstant(operands, /*isSpec=*/false); 169 case spirv::Opcode::OpSpecConstant: 170 return processConstant(operands, /*isSpec=*/true); 171 case spirv::Opcode::OpConstantComposite: 172 return processConstantComposite(operands); 173 case spirv::Opcode::OpSpecConstantComposite: 174 return processSpecConstantComposite(operands); 175 case spirv::Opcode::OpSpecConstantOperation: 176 return processSpecConstantOperation(operands); 177 case spirv::Opcode::OpConstantTrue: 178 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false); 179 case spirv::Opcode::OpSpecConstantTrue: 180 return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true); 181 case spirv::Opcode::OpConstantFalse: 182 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false); 183 case spirv::Opcode::OpSpecConstantFalse: 184 return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true); 185 case spirv::Opcode::OpConstantNull: 186 return processConstantNull(operands); 187 case spirv::Opcode::OpDecorate: 188 return processDecoration(operands); 189 case spirv::Opcode::OpMemberDecorate: 190 return processMemberDecoration(operands); 191 case spirv::Opcode::OpFunction: 192 return processFunction(operands); 193 case spirv::Opcode::OpLabel: 194 return processLabel(operands); 195 case spirv::Opcode::OpBranch: 196 return processBranch(operands); 197 case spirv::Opcode::OpBranchConditional: 198 return processBranchConditional(operands); 199 case spirv::Opcode::OpSelectionMerge: 200 return processSelectionMerge(operands); 201 case spirv::Opcode::OpLoopMerge: 202 return processLoopMerge(operands); 203 case spirv::Opcode::OpPhi: 204 return processPhi(operands); 205 case spirv::Opcode::OpUndef: 206 return processUndef(operands); 207 default: 208 break; 209 } 210 return dispatchToAutogenDeserialization(opcode, operands); 211 } 212 213 LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr( 214 ArrayRef<uint32_t> words, StringRef opName, bool hasResult, 215 unsigned numOperands) { 216 SmallVector<Type, 1> resultTypes; 217 uint32_t valueID = 0; 218 219 size_t wordIndex = 0; 220 if (hasResult) { 221 if (wordIndex >= words.size()) 222 return emitError(unknownLoc, 223 "expected result type <id> while deserializing for ") 224 << opName; 225 226 // Decode the type <id> 227 auto type = getType(words[wordIndex]); 228 if (!type) 229 return emitError(unknownLoc, "unknown type result <id>: ") 230 << words[wordIndex]; 231 resultTypes.push_back(type); 232 ++wordIndex; 233 234 // Decode the result <id> 235 if (wordIndex >= words.size()) 236 return emitError(unknownLoc, 237 "expected result <id> while deserializing for ") 238 << opName; 239 valueID = words[wordIndex]; 240 ++wordIndex; 241 } 242 243 SmallVector<Value, 4> operands; 244 SmallVector<NamedAttribute, 4> attributes; 245 246 // Decode operands 247 size_t operandIndex = 0; 248 for (; operandIndex < numOperands && wordIndex < words.size(); 249 ++operandIndex, ++wordIndex) { 250 auto arg = getValue(words[wordIndex]); 251 if (!arg) 252 return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex]; 253 operands.push_back(arg); 254 } 255 if (operandIndex != numOperands) { 256 return emitError( 257 unknownLoc, 258 "found less operands than expected when deserializing for ") 259 << opName << "; only " << operandIndex << " of " << numOperands 260 << " processed"; 261 } 262 if (wordIndex != words.size()) { 263 return emitError( 264 unknownLoc, 265 "found more operands than expected when deserializing for ") 266 << opName << "; only " << wordIndex << " of " << words.size() 267 << " processed"; 268 } 269 270 // Attach attributes from decorations 271 if (decorations.count(valueID)) { 272 auto attrs = decorations[valueID].getAttrs(); 273 attributes.append(attrs.begin(), attrs.end()); 274 } 275 276 // Create the op and update bookkeeping maps 277 Location loc = createFileLineColLoc(opBuilder); 278 OperationState opState(loc, opName); 279 opState.addOperands(operands); 280 if (hasResult) 281 opState.addTypes(resultTypes); 282 opState.addAttributes(attributes); 283 Operation *op = opBuilder.createOperation(opState); 284 if (hasResult) 285 valueMap[valueID] = op->getResult(0); 286 287 if (op->hasTrait<OpTrait::IsTerminator>()) 288 clearDebugLine(); 289 290 return success(); 291 } 292 293 LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) { 294 if (operands.size() != 2) { 295 return emitError(unknownLoc, "OpUndef instruction must have two operands"); 296 } 297 auto type = getType(operands[0]); 298 if (!type) { 299 return emitError(unknownLoc, "unknown type <id> with OpUndef instruction"); 300 } 301 undefMap[operands[1]] = type; 302 return success(); 303 } 304 305 LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) { 306 if (operands.size() < 4) { 307 return emitError(unknownLoc, 308 "OpExtInst must have at least 4 operands, result type " 309 "<id>, result <id>, set <id> and instruction opcode"); 310 } 311 if (!extendedInstSets.count(operands[2])) { 312 return emitError(unknownLoc, "undefined set <id> in OpExtInst"); 313 } 314 SmallVector<uint32_t, 4> slicedOperands; 315 slicedOperands.append(operands.begin(), std::next(operands.begin(), 2)); 316 slicedOperands.append(std::next(operands.begin(), 4), operands.end()); 317 return dispatchToExtensionSetAutogenDeserialization( 318 extendedInstSets[operands[2]], operands[3], slicedOperands); 319 } 320 321 namespace mlir { 322 namespace spirv { 323 324 template <> 325 LogicalResult 326 Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) { 327 unsigned wordIndex = 0; 328 if (wordIndex >= words.size()) { 329 return emitError(unknownLoc, 330 "missing Execution Model specification in OpEntryPoint"); 331 } 332 auto execModel = opBuilder.getI32IntegerAttr(words[wordIndex++]); 333 if (wordIndex >= words.size()) { 334 return emitError(unknownLoc, "missing <id> in OpEntryPoint"); 335 } 336 // Get the function <id> 337 auto fnID = words[wordIndex++]; 338 // Get the function name 339 auto fnName = decodeStringLiteral(words, wordIndex); 340 // Verify that the function <id> matches the fnName 341 auto parsedFunc = getFunction(fnID); 342 if (!parsedFunc) { 343 return emitError(unknownLoc, "no function matching <id> ") << fnID; 344 } 345 if (parsedFunc.getName() != fnName) { 346 return emitError(unknownLoc, "function name mismatch between OpEntryPoint " 347 "and OpFunction with <id> ") 348 << fnID << ": " << fnName << " vs. " << parsedFunc.getName(); 349 } 350 SmallVector<Attribute, 4> interface; 351 while (wordIndex < words.size()) { 352 auto arg = getGlobalVariable(words[wordIndex]); 353 if (!arg) { 354 return emitError(unknownLoc, "undefined result <id> ") 355 << words[wordIndex] << " while decoding OpEntryPoint"; 356 } 357 interface.push_back(opBuilder.getSymbolRefAttr(arg.getOperation())); 358 wordIndex++; 359 } 360 opBuilder.create<spirv::EntryPointOp>(unknownLoc, execModel, 361 opBuilder.getSymbolRefAttr(fnName), 362 opBuilder.getArrayAttr(interface)); 363 return success(); 364 } 365 366 template <> 367 LogicalResult 368 Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) { 369 unsigned wordIndex = 0; 370 if (wordIndex >= words.size()) { 371 return emitError(unknownLoc, 372 "missing function result <id> in OpExecutionMode"); 373 } 374 // Get the function <id> to get the name of the function 375 auto fnID = words[wordIndex++]; 376 auto fn = getFunction(fnID); 377 if (!fn) { 378 return emitError(unknownLoc, "no function matching <id> ") << fnID; 379 } 380 // Get the Execution mode 381 if (wordIndex >= words.size()) { 382 return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode"); 383 } 384 auto execMode = opBuilder.getI32IntegerAttr(words[wordIndex++]); 385 386 // Get the values 387 SmallVector<Attribute, 4> attrListElems; 388 while (wordIndex < words.size()) { 389 attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++])); 390 } 391 auto values = opBuilder.getArrayAttr(attrListElems); 392 opBuilder.create<spirv::ExecutionModeOp>( 393 unknownLoc, opBuilder.getSymbolRefAttr(fn.getName()), execMode, values); 394 return success(); 395 } 396 397 template <> 398 LogicalResult 399 Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) { 400 if (operands.size() != 3) { 401 return emitError( 402 unknownLoc, 403 "OpControlBarrier must have execution scope <id>, memory scope <id> " 404 "and memory semantics <id>"); 405 } 406 407 SmallVector<IntegerAttr, 3> argAttrs; 408 for (auto operand : operands) { 409 auto argAttr = getConstantInt(operand); 410 if (!argAttr) { 411 return emitError(unknownLoc, 412 "expected 32-bit integer constant from <id> ") 413 << operand << " for OpControlBarrier"; 414 } 415 argAttrs.push_back(argAttr); 416 } 417 418 opBuilder.create<spirv::ControlBarrierOp>(unknownLoc, argAttrs[0], 419 argAttrs[1], argAttrs[2]); 420 return success(); 421 } 422 423 template <> 424 LogicalResult 425 Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) { 426 if (operands.size() < 3) { 427 return emitError(unknownLoc, 428 "OpFunctionCall must have at least 3 operands"); 429 } 430 431 Type resultType = getType(operands[0]); 432 if (!resultType) { 433 return emitError(unknownLoc, "undefined result type from <id> ") 434 << operands[0]; 435 } 436 437 // Use null type to mean no result type. 438 if (isVoidType(resultType)) 439 resultType = nullptr; 440 441 auto resultID = operands[1]; 442 auto functionID = operands[2]; 443 444 auto functionName = getFunctionSymbol(functionID); 445 446 SmallVector<Value, 4> arguments; 447 for (auto operand : llvm::drop_begin(operands, 3)) { 448 auto value = getValue(operand); 449 if (!value) { 450 return emitError(unknownLoc, "unknown <id> ") 451 << operand << " used by OpFunctionCall"; 452 } 453 arguments.push_back(value); 454 } 455 456 auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>( 457 unknownLoc, resultType, opBuilder.getSymbolRefAttr(functionName), 458 arguments); 459 460 if (resultType) 461 valueMap[resultID] = opFunctionCall.getResult(0); 462 return success(); 463 } 464 465 template <> 466 LogicalResult 467 Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) { 468 if (operands.size() != 2) { 469 return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> " 470 "and memory semantics <id>"); 471 } 472 473 SmallVector<IntegerAttr, 2> argAttrs; 474 for (auto operand : operands) { 475 auto argAttr = getConstantInt(operand); 476 if (!argAttr) { 477 return emitError(unknownLoc, 478 "expected 32-bit integer constant from <id> ") 479 << operand << " for OpMemoryBarrier"; 480 } 481 argAttrs.push_back(argAttr); 482 } 483 484 opBuilder.create<spirv::MemoryBarrierOp>(unknownLoc, argAttrs[0], 485 argAttrs[1]); 486 return success(); 487 } 488 489 template <> 490 LogicalResult 491 Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) { 492 SmallVector<Type, 1> resultTypes; 493 size_t wordIndex = 0; 494 SmallVector<Value, 4> operands; 495 SmallVector<NamedAttribute, 4> attributes; 496 497 if (wordIndex < words.size()) { 498 auto arg = getValue(words[wordIndex]); 499 500 if (!arg) { 501 return emitError(unknownLoc, "unknown result <id> : ") 502 << words[wordIndex]; 503 } 504 505 operands.push_back(arg); 506 wordIndex++; 507 } 508 509 if (wordIndex < words.size()) { 510 auto arg = getValue(words[wordIndex]); 511 512 if (!arg) { 513 return emitError(unknownLoc, "unknown result <id> : ") 514 << words[wordIndex]; 515 } 516 517 operands.push_back(arg); 518 wordIndex++; 519 } 520 521 bool isAlignedAttr = false; 522 523 if (wordIndex < words.size()) { 524 auto attrValue = words[wordIndex++]; 525 attributes.push_back(opBuilder.getNamedAttr( 526 "memory_access", opBuilder.getI32IntegerAttr(attrValue))); 527 isAlignedAttr = (attrValue == 2); 528 } 529 530 if (isAlignedAttr && wordIndex < words.size()) { 531 attributes.push_back(opBuilder.getNamedAttr( 532 "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); 533 } 534 535 if (wordIndex < words.size()) { 536 attributes.push_back(opBuilder.getNamedAttr( 537 "source_memory_access", 538 opBuilder.getI32IntegerAttr(words[wordIndex++]))); 539 } 540 541 if (wordIndex < words.size()) { 542 attributes.push_back(opBuilder.getNamedAttr( 543 "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++]))); 544 } 545 546 if (wordIndex != words.size()) { 547 return emitError(unknownLoc, 548 "found more operands than expected when deserializing " 549 "spirv::CopyMemoryOp, only ") 550 << wordIndex << " of " << words.size() << " processed"; 551 } 552 553 Location loc = createFileLineColLoc(opBuilder); 554 opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes); 555 556 return success(); 557 } 558 559 // Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and 560 // various Deserializer::processOp<...>() specializations. 561 #define GET_DESERIALIZATION_FNS 562 #include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc" 563 564 } // namespace spirv 565 } // namespace mlir 566