//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation file for the abstraction of a tensor type, and JSON loading
// utils.
//
//===----------------------------------------------------------------------===//
13*b1fa5ac3SMircea Trofin #include "llvm/Config/config.h"
14*b1fa5ac3SMircea Trofin
15*b1fa5ac3SMircea Trofin #include "llvm/ADT/Twine.h"
16*b1fa5ac3SMircea Trofin #include "llvm/Analysis/TensorSpec.h"
17*b1fa5ac3SMircea Trofin #include "llvm/Support/CommandLine.h"
18*b1fa5ac3SMircea Trofin #include "llvm/Support/Debug.h"
19*b1fa5ac3SMircea Trofin #include "llvm/Support/JSON.h"
20*b1fa5ac3SMircea Trofin #include "llvm/Support/ManagedStatic.h"
21*b1fa5ac3SMircea Trofin #include "llvm/Support/MemoryBuffer.h"
22*b1fa5ac3SMircea Trofin #include "llvm/Support/Path.h"
23*b1fa5ac3SMircea Trofin #include "llvm/Support/raw_ostream.h"
24*b1fa5ac3SMircea Trofin #include <cassert>
25*b1fa5ac3SMircea Trofin #include <numeric>
26*b1fa5ac3SMircea Trofin
27*b1fa5ac3SMircea Trofin using namespace llvm;
28*b1fa5ac3SMircea Trofin
29*b1fa5ac3SMircea Trofin namespace llvm {
30*b1fa5ac3SMircea Trofin
31*b1fa5ac3SMircea Trofin #define TFUTILS_GETDATATYPE_IMPL(T, E) \
32*b1fa5ac3SMircea Trofin template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
33*b1fa5ac3SMircea Trofin
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)34*b1fa5ac3SMircea Trofin SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
35*b1fa5ac3SMircea Trofin
36*b1fa5ac3SMircea Trofin #undef TFUTILS_GETDATATYPE_IMPL
37*b1fa5ac3SMircea Trofin
38*b1fa5ac3SMircea Trofin TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
39*b1fa5ac3SMircea Trofin size_t ElementSize, const std::vector<int64_t> &Shape)
40*b1fa5ac3SMircea Trofin : Name(Name), Port(Port), Type(Type), Shape(Shape),
41*b1fa5ac3SMircea Trofin ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
42*b1fa5ac3SMircea Trofin std::multiplies<int64_t>())),
43*b1fa5ac3SMircea Trofin ElementSize(ElementSize) {}
44*b1fa5ac3SMircea Trofin
getTensorSpecFromJSON(LLVMContext & Ctx,const json::Value & Value)45*b1fa5ac3SMircea Trofin Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
46*b1fa5ac3SMircea Trofin const json::Value &Value) {
47*b1fa5ac3SMircea Trofin auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
48*b1fa5ac3SMircea Trofin std::string S;
49*b1fa5ac3SMircea Trofin llvm::raw_string_ostream OS(S);
50*b1fa5ac3SMircea Trofin OS << Value;
51*b1fa5ac3SMircea Trofin Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
52*b1fa5ac3SMircea Trofin return None;
53*b1fa5ac3SMircea Trofin };
54*b1fa5ac3SMircea Trofin // FIXME: accept a Path as a parameter, and use it for error reporting.
55*b1fa5ac3SMircea Trofin json::Path::Root Root("tensor_spec");
56*b1fa5ac3SMircea Trofin json::ObjectMapper Mapper(Value, Root);
57*b1fa5ac3SMircea Trofin if (!Mapper)
58*b1fa5ac3SMircea Trofin return EmitError("Value is not a dict");
59*b1fa5ac3SMircea Trofin
60*b1fa5ac3SMircea Trofin std::string TensorName;
61*b1fa5ac3SMircea Trofin int TensorPort = -1;
62*b1fa5ac3SMircea Trofin std::string TensorType;
63*b1fa5ac3SMircea Trofin std::vector<int64_t> TensorShape;
64*b1fa5ac3SMircea Trofin
65*b1fa5ac3SMircea Trofin if (!Mapper.map<std::string>("name", TensorName))
66*b1fa5ac3SMircea Trofin return EmitError("'name' property not present or not a string");
67*b1fa5ac3SMircea Trofin if (!Mapper.map<std::string>("type", TensorType))
68*b1fa5ac3SMircea Trofin return EmitError("'type' property not present or not a string");
69*b1fa5ac3SMircea Trofin if (!Mapper.map<int>("port", TensorPort))
70*b1fa5ac3SMircea Trofin return EmitError("'port' property not present or not an int");
71*b1fa5ac3SMircea Trofin if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
72*b1fa5ac3SMircea Trofin return EmitError("'shape' property not present or not an int array");
73*b1fa5ac3SMircea Trofin
74*b1fa5ac3SMircea Trofin #define PARSE_TYPE(T, E) \
75*b1fa5ac3SMircea Trofin if (TensorType == #T) \
76*b1fa5ac3SMircea Trofin return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
77*b1fa5ac3SMircea Trofin SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
78*b1fa5ac3SMircea Trofin #undef PARSE_TYPE
79*b1fa5ac3SMircea Trofin return None;
80*b1fa5ac3SMircea Trofin }
81*b1fa5ac3SMircea Trofin
82*b1fa5ac3SMircea Trofin Optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext & Ctx,StringRef ExpectedDecisionName,StringRef ModelPath,StringRef SpecFileOverride)83*b1fa5ac3SMircea Trofin loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
84*b1fa5ac3SMircea Trofin StringRef ModelPath, StringRef SpecFileOverride) {
85*b1fa5ac3SMircea Trofin SmallVector<char, 128> OutputSpecsPath;
86*b1fa5ac3SMircea Trofin StringRef FileName = SpecFileOverride;
87*b1fa5ac3SMircea Trofin if (FileName.empty()) {
88*b1fa5ac3SMircea Trofin llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
89*b1fa5ac3SMircea Trofin FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
90*b1fa5ac3SMircea Trofin }
91*b1fa5ac3SMircea Trofin
92*b1fa5ac3SMircea Trofin auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
93*b1fa5ac3SMircea Trofin if (!BufferOrError) {
94*b1fa5ac3SMircea Trofin Ctx.emitError("Error opening output specs file: " + FileName + " : " +
95*b1fa5ac3SMircea Trofin BufferOrError.getError().message());
96*b1fa5ac3SMircea Trofin return None;
97*b1fa5ac3SMircea Trofin }
98*b1fa5ac3SMircea Trofin auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
99*b1fa5ac3SMircea Trofin if (!ParsedJSONValues) {
100*b1fa5ac3SMircea Trofin Ctx.emitError("Could not parse specs file: " + FileName);
101*b1fa5ac3SMircea Trofin return None;
102*b1fa5ac3SMircea Trofin }
103*b1fa5ac3SMircea Trofin auto ValuesArray = ParsedJSONValues->getAsArray();
104*b1fa5ac3SMircea Trofin if (!ValuesArray) {
105*b1fa5ac3SMircea Trofin Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
106*b1fa5ac3SMircea Trofin "logging_name:<name>} dictionaries");
107*b1fa5ac3SMircea Trofin return None;
108*b1fa5ac3SMircea Trofin }
109*b1fa5ac3SMircea Trofin std::vector<LoggedFeatureSpec> Ret;
110*b1fa5ac3SMircea Trofin for (const auto &Value : *ValuesArray)
111*b1fa5ac3SMircea Trofin if (const auto *Obj = Value.getAsObject())
112*b1fa5ac3SMircea Trofin if (const auto *SpecPart = Obj->get("tensor_spec"))
113*b1fa5ac3SMircea Trofin if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
114*b1fa5ac3SMircea Trofin if (auto LoggingName = Obj->getString("logging_name")) {
115*b1fa5ac3SMircea Trofin if (!TensorSpec->isElementType<int64_t>() &&
116*b1fa5ac3SMircea Trofin !TensorSpec->isElementType<int32_t>() &&
117*b1fa5ac3SMircea Trofin !TensorSpec->isElementType<float>()) {
118*b1fa5ac3SMircea Trofin Ctx.emitError(
119*b1fa5ac3SMircea Trofin "Only int64, int32, and float tensors are supported. "
120*b1fa5ac3SMircea Trofin "Found unsupported type for tensor named " +
121*b1fa5ac3SMircea Trofin TensorSpec->name());
122*b1fa5ac3SMircea Trofin return None;
123*b1fa5ac3SMircea Trofin }
124*b1fa5ac3SMircea Trofin Ret.push_back({*TensorSpec, LoggingName->str()});
125*b1fa5ac3SMircea Trofin }
126*b1fa5ac3SMircea Trofin
127*b1fa5ac3SMircea Trofin if (ValuesArray->size() != Ret.size()) {
128*b1fa5ac3SMircea Trofin Ctx.emitError(
129*b1fa5ac3SMircea Trofin "Unable to parse output spec. It should be a json file containing an "
130*b1fa5ac3SMircea Trofin "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
131*b1fa5ac3SMircea Trofin "with a json object describing a TensorSpec; and a 'logging_name' key, "
132*b1fa5ac3SMircea Trofin "which is a string to use as name when logging this tensor in the "
133*b1fa5ac3SMircea Trofin "training log.");
134*b1fa5ac3SMircea Trofin return None;
135*b1fa5ac3SMircea Trofin }
136*b1fa5ac3SMircea Trofin if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
137*b1fa5ac3SMircea Trofin Ctx.emitError("The first output spec must describe the decision tensor, "
138*b1fa5ac3SMircea Trofin "and must have the logging_name " +
139*b1fa5ac3SMircea Trofin StringRef(ExpectedDecisionName));
140*b1fa5ac3SMircea Trofin return None;
141*b1fa5ac3SMircea Trofin }
142*b1fa5ac3SMircea Trofin return Ret;
143*b1fa5ac3SMircea Trofin }
144*b1fa5ac3SMircea Trofin } // namespace llvm
145