//===- DeserializeOps.cpp - MLIR SPIR-V Deserialization (Ops) -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the Deserializer methods for SPIR-V binary instructions.
//
//===----------------------------------------------------------------------===//

#include "Deserializer.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Debug.h"

using namespace mlir;

#define DEBUG_TYPE "spirv-deserialization"

//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//

/// Extracts the opcode from the given first word of a SPIR-V instruction.
static inline spirv::Opcode extractOpcode(uint32_t word) {
  return static_cast<spirv::Opcode>(word & 0xffff);
}

//===----------------------------------------------------------------------===//
// Instruction
//===----------------------------------------------------------------------===//

Value spirv::Deserializer::getValue(uint32_t id) {
  if (auto constInfo = getConstant(id)) {
    // Materialize a `spv.Constant` op at every use site.
    return opBuilder.create<spirv::ConstantOp>(unknownLoc, constInfo->second,
                                               constInfo->first);
  }
  if (auto varOp = getGlobalVariable(id)) {
    auto addressOfOp = opBuilder.create<spirv::AddressOfOp>(
        unknownLoc, varOp.type(), SymbolRefAttr::get(varOp.getOperation()));
    return addressOfOp.pointer();
  }
  if (auto constOp = getSpecConstant(id)) {
    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
        unknownLoc, constOp.default_value().getType(),
        SymbolRefAttr::get(constOp.getOperation()));
    return referenceOfOp.reference();
  }
  if (auto constCompositeOp = getSpecConstantComposite(id)) {
    auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
        unknownLoc, constCompositeOp.type(),
        SymbolRefAttr::get(constCompositeOp.getOperation()));
    return referenceOfOp.reference();
  }
  if (auto specConstOperationInfo = getSpecConstantOperation(id)) {
    return materializeSpecConstantOperation(
        id, specConstOperationInfo->enclodesOpcode,
        specConstOperationInfo->resultTypeID,
        specConstOperationInfo->enclosedOpOperands);
  }
  if (auto undef = getUndefType(id)) {
    return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
  }
  return valueMap.lookup(id);
}

LogicalResult
spirv::Deserializer::sliceInstruction(spirv::Opcode &opcode,
                                      ArrayRef<uint32_t> &operands,
                                      Optional<spirv::Opcode> expectedOpcode) {
  auto binarySize = binary.size();
  if (curOffset >= binarySize) {
    return emitError(unknownLoc, "expected ")
           << (expectedOpcode ? spirv::stringifyOpcode(*expectedOpcode)
                              : "more")
           << " instruction";
  }

  // For each instruction, get its word count from the first word to slice it
  // from the stream properly, and then dispatch to the instruction handler.

  uint32_t wordCount = binary[curOffset] >> 16;

  if (wordCount == 0)
    return emitError(unknownLoc, "word count cannot be zero");

  uint32_t nextOffset = curOffset + wordCount;
  if (nextOffset > binarySize)
    return emitError(unknownLoc, "insufficient words for the last instruction");

  opcode = extractOpcode(binary[curOffset]);
  operands = binary.slice(curOffset + 1, wordCount - 1);
  curOffset = nextOffset;
  return success();
}

LogicalResult spirv::Deserializer::processInstruction(
    spirv::Opcode opcode, ArrayRef<uint32_t> operands, bool deferInstructions) {
  LLVM_DEBUG(logger.startLine() << "[inst] processing instruction "
                                << spirv::stringifyOpcode(opcode) << "\n");

  // First dispatch all the instructions whose opcode does not correspond to
  // those that have a direct mirror in the SPIR-V dialect
  switch (opcode) {
  case spirv::Opcode::OpCapability:
    return processCapability(operands);
  case spirv::Opcode::OpExtension:
    return processExtension(operands);
  case spirv::Opcode::OpExtInst:
    return processExtInst(operands);
  case spirv::Opcode::OpExtInstImport:
    return processExtInstImport(operands);
  case spirv::Opcode::OpMemberName:
    return processMemberName(operands);
  case spirv::Opcode::OpMemoryModel:
    return processMemoryModel(operands);
  case spirv::Opcode::OpEntryPoint:
  case spirv::Opcode::OpExecutionMode:
    if (deferInstructions) {
      deferredInstructions.emplace_back(opcode, operands);
      return success();
    }
    break;
  case spirv::Opcode::OpVariable:
    if (isa<spirv::ModuleOp>(opBuilder.getBlock()->getParentOp())) {
      return processGlobalVariable(operands);
    }
    break;
  case spirv::Opcode::OpLine:
    return processDebugLine(operands);
  case spirv::Opcode::OpNoLine:
    clearDebugLine();
    return success();
  case spirv::Opcode::OpName:
    return processName(operands);
  case spirv::Opcode::OpString:
    return processDebugString(operands);
  case spirv::Opcode::OpModuleProcessed:
  case spirv::Opcode::OpSource:
  case spirv::Opcode::OpSourceContinued:
  case spirv::Opcode::OpSourceExtension:
    // TODO: This is debug information embedded in the binary which should be
    // translated into the spv.module.
    return success();
  case spirv::Opcode::OpTypeVoid:
  case spirv::Opcode::OpTypeBool:
  case spirv::Opcode::OpTypeInt:
  case spirv::Opcode::OpTypeFloat:
  case spirv::Opcode::OpTypeVector:
  case spirv::Opcode::OpTypeMatrix:
  case spirv::Opcode::OpTypeArray:
  case spirv::Opcode::OpTypeFunction:
  case spirv::Opcode::OpTypeImage:
  case spirv::Opcode::OpTypeSampledImage:
  case spirv::Opcode::OpTypeRuntimeArray:
  case spirv::Opcode::OpTypeStruct:
  case spirv::Opcode::OpTypePointer:
  case spirv::Opcode::OpTypeCooperativeMatrixNV:
    return processType(opcode, operands);
  case spirv::Opcode::OpTypeForwardPointer:
    return processTypeForwardPointer(operands);
  case spirv::Opcode::OpConstant:
    return processConstant(operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstant:
    return processConstant(operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantComposite:
    return processConstantComposite(operands);
  case spirv::Opcode::OpSpecConstantComposite:
    return processSpecConstantComposite(operands);
  case spirv::Opcode::OpSpecConstantOp:
    return processSpecConstantOperation(operands);
  case spirv::Opcode::OpConstantTrue:
    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstantTrue:
    return processConstantBool(/*isTrue=*/true, operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantFalse:
    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/false);
  case spirv::Opcode::OpSpecConstantFalse:
    return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
  case spirv::Opcode::OpConstantNull:
    return processConstantNull(operands);
  case spirv::Opcode::OpDecorate:
    return processDecoration(operands);
  case spirv::Opcode::OpMemberDecorate:
    return processMemberDecoration(operands);
  case spirv::Opcode::OpFunction:
    return processFunction(operands);
  case spirv::Opcode::OpLabel:
    return processLabel(operands);
  case spirv::Opcode::OpBranch:
    return processBranch(operands);
  case spirv::Opcode::OpBranchConditional:
    return processBranchConditional(operands);
  case spirv::Opcode::OpSelectionMerge:
    return processSelectionMerge(operands);
  case spirv::Opcode::OpLoopMerge:
    return processLoopMerge(operands);
  case spirv::Opcode::OpPhi:
    return processPhi(operands);
  case spirv::Opcode::OpUndef:
    return processUndef(operands);
  default:
    break;
  }
  return dispatchToAutogenDeserialization(opcode, operands);
}

LogicalResult spirv::Deserializer::processOpWithoutGrammarAttr(
    ArrayRef<uint32_t> words, StringRef opName, bool hasResult,
    unsigned numOperands) {
  SmallVector<Type, 1> resultTypes;
  uint32_t valueID = 0;

  size_t wordIndex = 0;
  if (hasResult) {
    if (wordIndex >= words.size())
      return emitError(unknownLoc,
                       "expected result type <id> while deserializing for ")
             << opName;

    // Decode the type <id>
    auto type = getType(words[wordIndex]);
    if (!type)
      return emitError(unknownLoc, "unknown type result <id>: ")
             << words[wordIndex];
    resultTypes.push_back(type);
    ++wordIndex;

    // Decode the result <id>
    if (wordIndex >= words.size())
      return emitError(unknownLoc,
                       "expected result <id> while deserializing for ")
             << opName;
    valueID = words[wordIndex];
    ++wordIndex;
  }

  SmallVector<Value, 4> operands;
  SmallVector<NamedAttribute, 4> attributes;

  // Decode operands
  size_t operandIndex = 0;
  for (; operandIndex < numOperands && wordIndex < words.size();
       ++operandIndex, ++wordIndex) {
    auto arg = getValue(words[wordIndex]);
    if (!arg)
      return emitError(unknownLoc, "unknown result <id>: ") << words[wordIndex];
    operands.push_back(arg);
  }
  if (operandIndex != numOperands) {
    return emitError(
               unknownLoc,
               "found less operands than expected when deserializing for ")
           << opName << "; only " << operandIndex << " of " << numOperands
           << " processed";
  }
  if (wordIndex != words.size()) {
    return emitError(
               unknownLoc,
               "found more operands than expected when deserializing for ")
           << opName << "; only " << wordIndex << " of " << words.size()
           << " processed";
  }

  // Attach attributes from decorations
  if (decorations.count(valueID)) {
    auto attrs = decorations[valueID].getAttrs();
    attributes.append(attrs.begin(), attrs.end());
  }

  // Create the op and update bookkeeping maps
  Location loc = createFileLineColLoc(opBuilder);
  OperationState opState(loc, opName);
  opState.addOperands(operands);
  if (hasResult)
    opState.addTypes(resultTypes);
  opState.addAttributes(attributes);
  Operation *op = opBuilder.createOperation(opState);
  if (hasResult)
    valueMap[valueID] = op->getResult(0);

  if (op->hasTrait<OpTrait::IsTerminator>())
    clearDebugLine();

  return success();
}

LogicalResult spirv::Deserializer::processUndef(ArrayRef<uint32_t> operands) {
  if (operands.size() != 2) {
    return emitError(unknownLoc, "OpUndef instruction must have two operands");
  }
  auto type = getType(operands[0]);
  if (!type) {
    return emitError(unknownLoc, "unknown type <id> with OpUndef instruction");
  }
  undefMap[operands[1]] = type;
  return success();
}

LogicalResult spirv::Deserializer::processExtInst(ArrayRef<uint32_t> operands) {
  if (operands.size() < 4) {
    return emitError(unknownLoc,
                     "OpExtInst must have at least 4 operands, result type "
                     "<id>, result <id>, set <id> and instruction opcode");
  }
  if (!extendedInstSets.count(operands[2])) {
    return emitError(unknownLoc, "undefined set <id> in OpExtInst");
  }
  SmallVector<uint32_t, 4> slicedOperands;
  slicedOperands.append(operands.begin(), std::next(operands.begin(), 2));
  slicedOperands.append(std::next(operands.begin(), 4), operands.end());
  return dispatchToExtensionSetAutogenDeserialization(
      extendedInstSets[operands[2]], operands[3], slicedOperands);
}

namespace mlir {
namespace spirv {

template <>
LogicalResult
Deserializer::processOp<spirv::EntryPointOp>(ArrayRef<uint32_t> words) {
  unsigned wordIndex = 0;
  if (wordIndex >= words.size()) {
    return emitError(unknownLoc,
                     "missing Execution Model specification in OpEntryPoint");
  }
  auto execModel = spirv::ExecutionModelAttr::get(
      context, static_cast<spirv::ExecutionModel>(words[wordIndex++]));
  if (wordIndex >= words.size()) {
    return emitError(unknownLoc, "missing <id> in OpEntryPoint");
  }
  // Get the function <id>
  auto fnID = words[wordIndex++];
  // Get the function name
  auto fnName = decodeStringLiteral(words, wordIndex);
  // Verify that the function <id> matches the fnName
  auto parsedFunc = getFunction(fnID);
  if (!parsedFunc) {
    return emitError(unknownLoc, "no function matching <id> ") << fnID;
  }
  if (parsedFunc.getName() != fnName) {
    return emitError(unknownLoc, "function name mismatch between OpEntryPoint "
                                 "and OpFunction with <id> ")
           << fnID << ": " << fnName << " vs. " << parsedFunc.getName();
  }
  SmallVector<Attribute, 4> interface;
  while (wordIndex < words.size()) {
    auto arg = getGlobalVariable(words[wordIndex]);
    if (!arg) {
      return emitError(unknownLoc, "undefined result <id> ")
             << words[wordIndex] << " while decoding OpEntryPoint";
    }
    interface.push_back(SymbolRefAttr::get(arg.getOperation()));
    wordIndex++;
  }
  opBuilder.create<spirv::EntryPointOp>(
      unknownLoc, execModel, SymbolRefAttr::get(opBuilder.getContext(), fnName),
      opBuilder.getArrayAttr(interface));
  return success();
}

template <>
LogicalResult
Deserializer::processOp<spirv::ExecutionModeOp>(ArrayRef<uint32_t> words) {
  unsigned wordIndex = 0;
  if (wordIndex >= words.size()) {
    return emitError(unknownLoc,
                     "missing function result <id> in OpExecutionMode");
  }
  // Get the function <id> to get the name of the function
  auto fnID = words[wordIndex++];
  auto fn = getFunction(fnID);
  if (!fn) {
    return emitError(unknownLoc, "no function matching <id> ") << fnID;
  }
  // Get the Execution mode
  if (wordIndex >= words.size()) {
    return emitError(unknownLoc, "missing Execution Mode in OpExecutionMode");
  }
  auto execMode = spirv::ExecutionModeAttr::get(
      context, static_cast<spirv::ExecutionMode>(words[wordIndex++]));

  // Get the values
  SmallVector<Attribute, 4> attrListElems;
  while (wordIndex < words.size()) {
    attrListElems.push_back(opBuilder.getI32IntegerAttr(words[wordIndex++]));
  }
  auto values = opBuilder.getArrayAttr(attrListElems);
  opBuilder.create<spirv::ExecutionModeOp>(
      unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), fn.getName()),
      execMode, values);
  return success();
}

template <>
LogicalResult
Deserializer::processOp<spirv::ControlBarrierOp>(ArrayRef<uint32_t> operands) {
  if (operands.size() != 3) {
    return emitError(
        unknownLoc,
        "OpControlBarrier must have execution scope <id>, memory scope <id> "
        "and memory semantics <id>");
  }

  SmallVector<IntegerAttr, 3> argAttrs;
  for (auto operand : operands) {
    auto argAttr = getConstantInt(operand);
    if (!argAttr) {
      return emitError(unknownLoc,
                       "expected 32-bit integer constant from <id> ")
             << operand << " for OpControlBarrier";
    }
    argAttrs.push_back(argAttr);
  }

  opBuilder.create<spirv::ControlBarrierOp>(
      unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
      argAttrs[1].cast<spirv::ScopeAttr>(),
      argAttrs[2].cast<spirv::MemorySemanticsAttr>());

  return success();
}

template <>
LogicalResult
Deserializer::processOp<spirv::FunctionCallOp>(ArrayRef<uint32_t> operands) {
  if (operands.size() < 3) {
    return emitError(unknownLoc,
                     "OpFunctionCall must have at least 3 operands");
  }

  Type resultType = getType(operands[0]);
  if (!resultType) {
    return emitError(unknownLoc, "undefined result type from <id> ")
           << operands[0];
  }

  // Use null type to mean no result type.
  if (isVoidType(resultType))
    resultType = nullptr;

  auto resultID = operands[1];
  auto functionID = operands[2];

  auto functionName = getFunctionSymbol(functionID);

  SmallVector<Value, 4> arguments;
  for (auto operand : llvm::drop_begin(operands, 3)) {
    auto value = getValue(operand);
    if (!value) {
      return emitError(unknownLoc, "unknown <id> ")
             << operand << " used by OpFunctionCall";
    }
    arguments.push_back(value);
  }

  auto opFunctionCall = opBuilder.create<spirv::FunctionCallOp>(
      unknownLoc, resultType,
      SymbolRefAttr::get(opBuilder.getContext(), functionName), arguments);

  if (resultType)
    valueMap[resultID] = opFunctionCall.getResult(0);
  return success();
}

template <>
LogicalResult
Deserializer::processOp<spirv::MemoryBarrierOp>(ArrayRef<uint32_t> operands) {
  if (operands.size() != 2) {
    return emitError(unknownLoc, "OpMemoryBarrier must have memory scope <id> "
                                 "and memory semantics <id>");
  }

  SmallVector<IntegerAttr, 2> argAttrs;
  for (auto operand : operands) {
    auto argAttr = getConstantInt(operand);
    if (!argAttr) {
      return emitError(unknownLoc,
                       "expected 32-bit integer constant from <id> ")
             << operand << " for OpMemoryBarrier";
    }
    argAttrs.push_back(argAttr);
  }

  opBuilder.create<spirv::MemoryBarrierOp>(
      unknownLoc, argAttrs[0].cast<spirv::ScopeAttr>(),
      argAttrs[1].cast<spirv::MemorySemanticsAttr>());
  return success();
}

template <>
LogicalResult
Deserializer::processOp<spirv::CopyMemoryOp>(ArrayRef<uint32_t> words) {
  SmallVector<Type, 1> resultTypes;
  size_t wordIndex = 0;
  SmallVector<Value, 4> operands;
  SmallVector<NamedAttribute, 4> attributes;

  if (wordIndex < words.size()) {
    auto arg = getValue(words[wordIndex]);

    if (!arg) {
      return emitError(unknownLoc, "unknown result <id> : ")
             << words[wordIndex];
    }

    operands.push_back(arg);
    wordIndex++;
  }

  if (wordIndex < words.size()) {
    auto arg = getValue(words[wordIndex]);

    if (!arg) {
      return emitError(unknownLoc, "unknown result <id> : ")
             << words[wordIndex];
    }

    operands.push_back(arg);
    wordIndex++;
  }

  bool isAlignedAttr = false;

  if (wordIndex < words.size()) {
    auto attrValue = words[wordIndex++];
    attributes.push_back(opBuilder.getNamedAttr(
        "memory_access", opBuilder.getI32IntegerAttr(attrValue)));
    isAlignedAttr = (attrValue == 2);
  }

  if (isAlignedAttr && wordIndex < words.size()) {
    attributes.push_back(opBuilder.getNamedAttr(
        "alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
  }

  if (wordIndex < words.size()) {
    attributes.push_back(opBuilder.getNamedAttr(
        "source_memory_access",
        opBuilder.getI32IntegerAttr(words[wordIndex++])));
  }

  if (wordIndex < words.size()) {
    attributes.push_back(opBuilder.getNamedAttr(
        "source_alignment", opBuilder.getI32IntegerAttr(words[wordIndex++])));
  }

  if (wordIndex != words.size()) {
    return emitError(unknownLoc,
                     "found more operands than expected when deserializing "
                     "spirv::CopyMemoryOp, only ")
           << wordIndex << " of " << words.size() << " processed";
  }

  Location loc = createFileLineColLoc(opBuilder);
  opBuilder.create<spirv::CopyMemoryOp>(loc, resultTypes, operands, attributes);

  return success();
}

// Pull in auto-generated Deserializer::dispatchToAutogenDeserialization() and
// various Deserializer::processOp<...>() specializations.
#define GET_DESERIALIZATION_FNS
#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"

} // namespace spirv
} // namespace mlir
