1caf395eeSMircea Trofin //===- TFUtils.cpp - tensorflow evaluation utilities ----------------------===//
2caf395eeSMircea Trofin //
3c874dd53SChristopher Di Bella // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c874dd53SChristopher Di Bella // See https://llvm.org/LICENSE.txt for license information.
5c874dd53SChristopher Di Bella // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6caf395eeSMircea Trofin //
7caf395eeSMircea Trofin //===----------------------------------------------------------------------===//
8caf395eeSMircea Trofin //
9caf395eeSMircea Trofin // This file implements utilities for interfacing with tensorflow C APIs.
10caf395eeSMircea Trofin //
11caf395eeSMircea Trofin //===----------------------------------------------------------------------===//
124fe912f1SNico Weber #include "llvm/Config/config.h"
134fe912f1SNico Weber #if defined(LLVM_HAVE_TF_API)
14caf395eeSMircea Trofin
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/Support/Base64.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"

#include "google/protobuf/struct.pb.h"
#include "google/protobuf/text_format.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/core/example/example.pb.h"
#include <cassert>
#include <cstring>
#include <numeric>
32caf395eeSMircea Trofin
33caf395eeSMircea Trofin using namespace llvm;
34caf395eeSMircea Trofin
3555e12f70SMircea Trofin using google::protobuf::Message;
3655e12f70SMircea Trofin using google::protobuf::TextFormat;
3755e12f70SMircea Trofin
3855e2d206SMircea Trofin static cl::opt<bool>
3955e2d206SMircea Trofin ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden,
4055e2d206SMircea Trofin cl::desc("Output textual (human-readable) protobuf."));
4155e2d206SMircea Trofin
42caf395eeSMircea Trofin namespace {
43caf395eeSMircea Trofin
444f763b21SMircea Trofin using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
454f763b21SMircea Trofin using TFSessionOptionsPtr =
464f763b21SMircea Trofin std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>;
474f763b21SMircea Trofin using TFStatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
484f763b21SMircea Trofin
49caf395eeSMircea Trofin struct TFInitializer {
TFInitializer__anon5e75051e0111::TFInitializer50caf395eeSMircea Trofin TFInitializer() {
51caf395eeSMircea Trofin int Argc = 1;
52caf395eeSMircea Trofin const char *Name = "";
53caf395eeSMircea Trofin const char **NamePtr = &Name;
54caf395eeSMircea Trofin TF_InitMain(Name, &Argc, const_cast<char ***>(&NamePtr));
55caf395eeSMircea Trofin }
56caf395eeSMircea Trofin };
57caf395eeSMircea Trofin
ensureInitTF()58*ede60037SNicolai Hähnle bool ensureInitTF() {
59*ede60037SNicolai Hähnle static TFInitializer TFLibInitializer;
60*ede60037SNicolai Hähnle return true;
61*ede60037SNicolai Hähnle }
62caf395eeSMircea Trofin
createTFGraph()634f763b21SMircea Trofin TFGraphPtr createTFGraph() {
644f763b21SMircea Trofin return TFGraphPtr(TF_NewGraph(), &TF_DeleteGraph);
65caf395eeSMircea Trofin }
66caf395eeSMircea Trofin
createTFStatus()674f763b21SMircea Trofin TFStatusPtr createTFStatus() {
684f763b21SMircea Trofin return TFStatusPtr(TF_NewStatus(), &TF_DeleteStatus);
69caf395eeSMircea Trofin }
70caf395eeSMircea Trofin
createTFSessionOptions()714f763b21SMircea Trofin TFSessionOptionsPtr createTFSessionOptions() {
724f763b21SMircea Trofin return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
73caf395eeSMircea Trofin }
741f5dceb1SMircea Trofin
serialize(const Message & SE,std::string * OutStr)751f5dceb1SMircea Trofin void serialize(const Message &SE, std::string *OutStr) {
761f5dceb1SMircea Trofin if (ProtobufTextMode) {
771f5dceb1SMircea Trofin TextFormat::PrintToString(SE, OutStr);
781f5dceb1SMircea Trofin } else {
791f5dceb1SMircea Trofin *OutStr = SE.SerializeAsString();
801f5dceb1SMircea Trofin }
811f5dceb1SMircea Trofin }
82e4794ff5SMircea Trofin
getTFTypeIndex(TensorType TType)83e4794ff5SMircea Trofin int getTFTypeIndex(TensorType TType) {
84e4794ff5SMircea Trofin switch (TType) {
85e4794ff5SMircea Trofin case TensorType::Double:
86e4794ff5SMircea Trofin return TF_DOUBLE;
87e4794ff5SMircea Trofin case TensorType::Float:
88e4794ff5SMircea Trofin return TF_FLOAT;
89e4794ff5SMircea Trofin case TensorType::Int8:
90e4794ff5SMircea Trofin return TF_INT8;
91e4794ff5SMircea Trofin case TensorType::UInt8:
92e4794ff5SMircea Trofin return TF_UINT8;
93e4794ff5SMircea Trofin case TensorType::Int16:
94e4794ff5SMircea Trofin return TF_INT16;
95e4794ff5SMircea Trofin case TensorType::UInt16:
96e4794ff5SMircea Trofin return TF_UINT16;
97e4794ff5SMircea Trofin case TensorType::Int32:
98e4794ff5SMircea Trofin return TF_INT32;
99e4794ff5SMircea Trofin case TensorType::UInt32:
100e4794ff5SMircea Trofin return TF_UINT32;
101e4794ff5SMircea Trofin case TensorType::Int64:
102e4794ff5SMircea Trofin return TF_INT64;
103e4794ff5SMircea Trofin case TensorType::UInt64:
104e4794ff5SMircea Trofin return TF_UINT64;
105e4794ff5SMircea Trofin case TensorType::Invalid:
106e4794ff5SMircea Trofin llvm_unreachable("Unknown tensor type");
107e4794ff5SMircea Trofin }
108e4794ff5SMircea Trofin }
109caf395eeSMircea Trofin } // namespace
110caf395eeSMircea Trofin
1114f763b21SMircea Trofin namespace llvm {
1124f763b21SMircea Trofin class EvaluationResultImpl {
1134f763b21SMircea Trofin public:
EvaluationResultImpl(size_t OutputSize)1144f763b21SMircea Trofin EvaluationResultImpl(size_t OutputSize)
1154f763b21SMircea Trofin : OutputSize(OutputSize), Output(OutputSize){};
1164f763b21SMircea Trofin
~EvaluationResultImpl()1174f763b21SMircea Trofin ~EvaluationResultImpl() {
1184f763b21SMircea Trofin for (auto *P : Output)
1194f763b21SMircea Trofin if (P)
1204f763b21SMircea Trofin TF_DeleteTensor(P);
1214f763b21SMircea Trofin }
1224f763b21SMircea Trofin
1234f763b21SMircea Trofin EvaluationResultImpl(const EvaluationResultImpl &) = delete;
1244f763b21SMircea Trofin EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
getOutput()1254f763b21SMircea Trofin std::vector<TF_Tensor *> &getOutput() { return Output; }
1264f763b21SMircea Trofin
1274f763b21SMircea Trofin private:
1284f763b21SMircea Trofin const size_t OutputSize;
1294f763b21SMircea Trofin std::vector<TF_Tensor *> Output;
1304f763b21SMircea Trofin };
1314f763b21SMircea Trofin
1324f763b21SMircea Trofin class TFModelEvaluatorImpl {
1334f763b21SMircea Trofin public:
1344f763b21SMircea Trofin TFModelEvaluatorImpl(StringRef SavedModelPath,
13571059257SMircea Trofin const std::vector<TensorSpec> &InputSpecs,
136b51e844fSMircea Trofin function_ref<TensorSpec(size_t)> GetOutputSpecs,
137b51e844fSMircea Trofin size_t OutputSpecsSize, const char *Tags);
1384f763b21SMircea Trofin
isValid() const1394f763b21SMircea Trofin bool isValid() const { return IsValid; }
OutputSize() const1404f763b21SMircea Trofin size_t OutputSize() const { return OutputFeed.size(); }
1414f763b21SMircea Trofin
evaluate(TF_Tensor ** Output,TF_Status * Status)1424f763b21SMircea Trofin void evaluate(TF_Tensor **Output, TF_Status *Status) {
1434f763b21SMircea Trofin TF_SessionRun(Session, nullptr, InputFeed.data(), Input.data(),
1444f763b21SMircea Trofin Input.size(), OutputFeed.data(), Output, OutputFeed.size(),
1454f763b21SMircea Trofin nullptr, 0, nullptr, Status);
1464f763b21SMircea Trofin }
1474f763b21SMircea Trofin
1484f763b21SMircea Trofin void initInput(size_t Index, TF_DataType Type,
1494f763b21SMircea Trofin const std::vector<int64_t> &Dimensions);
getInput() const1504f763b21SMircea Trofin const std::vector<TF_Tensor *> &getInput() const { return Input; }
1514f763b21SMircea Trofin
1524f763b21SMircea Trofin ~TFModelEvaluatorImpl();
1534f763b21SMircea Trofin
1544f763b21SMircea Trofin private:
1554f763b21SMircea Trofin /// The objects necessary for carrying out an evaluation of the SavedModel.
1564f763b21SMircea Trofin /// They are expensive to set up, and we maintain them accross all the
1574f763b21SMircea Trofin /// evaluations of the model.
1584f763b21SMircea Trofin TF_Session *Session = nullptr;
1594f763b21SMircea Trofin TFGraphPtr Graph;
1604f763b21SMircea Trofin TFSessionOptionsPtr Options;
1614f763b21SMircea Trofin
1624f763b21SMircea Trofin /// The specification of the input nodes.
1634f763b21SMircea Trofin std::vector<TF_Output> InputFeed;
1644f763b21SMircea Trofin
1654f763b21SMircea Trofin /// The input tensors. They must match by index of the corresponding InputFeed
1664f763b21SMircea Trofin /// value. We set up the tensors once and just mutate theirs scalars before
1674f763b21SMircea Trofin /// each evaluation. The input tensors keep their value after an evaluation.
1684f763b21SMircea Trofin std::vector<TF_Tensor *> Input;
1694f763b21SMircea Trofin
1704f763b21SMircea Trofin /// The specification of the output nodes. When evaluating, the tensors in the
1714f763b21SMircea Trofin /// output tensor vector must match by index the corresponding element in the
1724f763b21SMircea Trofin /// OutputFeed.
1734f763b21SMircea Trofin std::vector<TF_Output> OutputFeed;
1744f763b21SMircea Trofin
invalidate()1754f763b21SMircea Trofin void invalidate() { IsValid = false; }
1764f763b21SMircea Trofin
1774f763b21SMircea Trofin bool IsValid = true;
1784f763b21SMircea Trofin
1794f763b21SMircea Trofin /// Reusable utility for ensuring we can bind the requested Name to a node in
1804f763b21SMircea Trofin /// the SavedModel Graph.
18171059257SMircea Trofin bool checkReportAndInvalidate(const TF_Output &Output,
18271059257SMircea Trofin const TensorSpec &OutputSpec);
1834f763b21SMircea Trofin };
18455e12f70SMircea Trofin
18555e12f70SMircea Trofin class LoggerDataImpl {
18655e12f70SMircea Trofin const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
18755e12f70SMircea Trofin const TensorSpec RewardSpec;
188ae1a2a09SMircea Trofin const bool IncludeReward;
18955e12f70SMircea Trofin
190ae1a2a09SMircea Trofin std::vector<tensorflow::FeatureList> FeatureLists;
191ae1a2a09SMircea Trofin tensorflow::FeatureList Reward;
192ae1a2a09SMircea Trofin
isSelfConsistent(const tensorflow::SequenceExample & SE,size_t NrRecords) const193ae1a2a09SMircea Trofin bool isSelfConsistent(const tensorflow::SequenceExample &SE,
194ae1a2a09SMircea Trofin size_t NrRecords) const {
195ae1a2a09SMircea Trofin bool Ret = true;
196ae1a2a09SMircea Trofin for (const auto &TSpecs : LoggedFeatureSpecs) {
197ae1a2a09SMircea Trofin const auto &Name = TSpecs.getLoggingName();
198ae1a2a09SMircea Trofin const auto &FL = SE.feature_lists().feature_list().at(Name).feature();
199ae1a2a09SMircea Trofin if (NrRecords != static_cast<size_t>(FL.size())) {
200ae1a2a09SMircea Trofin dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected "
201ae1a2a09SMircea Trofin << NrRecords << " got " << FL.size() << "\n";
202ae1a2a09SMircea Trofin Ret = false;
203ae1a2a09SMircea Trofin }
204ae1a2a09SMircea Trofin }
205ae1a2a09SMircea Trofin if (IncludeReward && static_cast<size_t>(SE.feature_lists()
206ae1a2a09SMircea Trofin .feature_list()
207ae1a2a09SMircea Trofin .at(RewardSpec.name())
208ae1a2a09SMircea Trofin .feature()
209ae1a2a09SMircea Trofin .size()) != NrRecords) {
210ae1a2a09SMircea Trofin dbgs() << "[TF-UTILS]: reward is missing records.\n";
211ae1a2a09SMircea Trofin Ret = false;
212ae1a2a09SMircea Trofin }
213ae1a2a09SMircea Trofin return Ret;
214ae1a2a09SMircea Trofin }
215ae1a2a09SMircea Trofin
transferLog(tensorflow::SequenceExample & SE)216ae1a2a09SMircea Trofin void transferLog(tensorflow::SequenceExample &SE) {
217ae1a2a09SMircea Trofin auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
218ae1a2a09SMircea Trofin if (IncludeReward)
2198dc3fe0cSMircea Trofin (*FL)[RewardSpec.name()] = std::move(Reward);
220ae1a2a09SMircea Trofin assert(FeatureLists.size() == LoggedFeatureSpecs.size());
221ae1a2a09SMircea Trofin for (size_t I = 0; I < FeatureLists.size(); ++I) {
222ae1a2a09SMircea Trofin const auto &LFS = LoggedFeatureSpecs[I];
2238dc3fe0cSMircea Trofin (*FL)[LFS.getLoggingName()] = std::move(FeatureLists[I]);
224ae1a2a09SMircea Trofin }
225ae1a2a09SMircea Trofin }
22655e12f70SMircea Trofin
22755e12f70SMircea Trofin public:
LoggerDataImpl(const std::vector<LoggedFeatureSpec> & LoggedSpecs,const TensorSpec & RewardSpec,bool IncludeReward)22855e12f70SMircea Trofin LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
22955e12f70SMircea Trofin const TensorSpec &RewardSpec, bool IncludeReward)
230ae1a2a09SMircea Trofin : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
231ae1a2a09SMircea Trofin IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
23255e12f70SMircea Trofin
233ae1a2a09SMircea Trofin // flush the logged info to a stream and clear the log contents.
flush(std::string * Str)2341f5dceb1SMircea Trofin void flush(std::string *Str) {
235ae1a2a09SMircea Trofin size_t NrRecords = getNrRecords();
236510402c2SMircea Trofin (void)NrRecords;
237ae1a2a09SMircea Trofin tensorflow::SequenceExample SE;
238ae1a2a09SMircea Trofin transferLog(SE);
239ae1a2a09SMircea Trofin assert(isSelfConsistent(SE, NrRecords));
2401f5dceb1SMircea Trofin serialize(SE, Str);
24155e12f70SMircea Trofin }
24255e12f70SMircea Trofin
addNewTensor(size_t FeatureID)24355e12f70SMircea Trofin char *addNewTensor(size_t FeatureID) {
24455e12f70SMircea Trofin const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
24555e12f70SMircea Trofin if (Spec.isElementType<float>()) {
24655e12f70SMircea Trofin auto *RF = FeatureLists[FeatureID]
247ae1a2a09SMircea Trofin .add_feature()
24855e12f70SMircea Trofin ->mutable_float_list()
24955e12f70SMircea Trofin ->mutable_value();
25055e12f70SMircea Trofin RF->Resize(Spec.getElementCount(), 0.0);
25155e12f70SMircea Trofin return reinterpret_cast<char *>(RF->mutable_data());
25255e12f70SMircea Trofin } else if (Spec.isElementType<int32_t>() || Spec.isElementType<int64_t>()) {
25355e12f70SMircea Trofin auto *RF = FeatureLists[FeatureID]
254ae1a2a09SMircea Trofin .add_feature()
25555e12f70SMircea Trofin ->mutable_int64_list()
25655e12f70SMircea Trofin ->mutable_value();
25755e12f70SMircea Trofin RF->Resize(Spec.getElementCount(), 0);
25855e12f70SMircea Trofin return reinterpret_cast<char *>(RF->mutable_data());
25955e12f70SMircea Trofin }
26055e12f70SMircea Trofin llvm_unreachable("Unsupported tensor type.");
26155e12f70SMircea Trofin }
26255e12f70SMircea Trofin
logReward(T Value)26355e12f70SMircea Trofin template <typename T> void logReward(T Value) {
264ae1a2a09SMircea Trofin assert(IncludeReward);
26555e12f70SMircea Trofin if (RewardSpec.isElementType<float>())
266ae1a2a09SMircea Trofin Reward.add_feature()->mutable_float_list()->add_value(Value);
26755e12f70SMircea Trofin else if (RewardSpec.isElementType<int32_t>() ||
26855e12f70SMircea Trofin RewardSpec.isElementType<int64_t>())
269ae1a2a09SMircea Trofin Reward.add_feature()->mutable_int64_list()->add_value(Value);
27055e12f70SMircea Trofin else
27155e12f70SMircea Trofin llvm_unreachable("Unsupported tensor type.");
27255e12f70SMircea Trofin }
27355e12f70SMircea Trofin
getNrRecords() const27455e12f70SMircea Trofin size_t getNrRecords() const {
275ae1a2a09SMircea Trofin return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
27655e12f70SMircea Trofin }
27755e12f70SMircea Trofin };
2784f763b21SMircea Trofin } // namespace llvm
2794f763b21SMircea Trofin
TFModelEvaluatorImpl(StringRef SavedModelPath,const std::vector<TensorSpec> & InputSpecs,function_ref<TensorSpec (size_t)> GetOutputSpecs,size_t OutputSpecsSize,const char * Tags="serve")2804f763b21SMircea Trofin TFModelEvaluatorImpl::TFModelEvaluatorImpl(
28171059257SMircea Trofin StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
282b51e844fSMircea Trofin function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
283b51e844fSMircea Trofin const char *Tags = "serve")
284caf395eeSMircea Trofin : Graph(createTFGraph()), Options(createTFSessionOptions()),
28571059257SMircea Trofin InputFeed(InputSpecs.size()), Input(InputSpecs.size()),
286b51e844fSMircea Trofin OutputFeed(OutputSpecsSize) {
287caf395eeSMircea Trofin if (!ensureInitTF()) {
288caf395eeSMircea Trofin errs() << "Tensorflow should have been initialized";
289caf395eeSMircea Trofin return;
290caf395eeSMircea Trofin }
291caf395eeSMircea Trofin auto Status = createTFStatus();
292caf395eeSMircea Trofin
293caf395eeSMircea Trofin Session = TF_LoadSessionFromSavedModel(Options.get(), nullptr,
294caf395eeSMircea Trofin SavedModelPath.str().c_str(), &Tags, 1,
295caf395eeSMircea Trofin Graph.get(), nullptr, Status.get());
296caf395eeSMircea Trofin if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
297caf395eeSMircea Trofin errs() << TF_Message(Status.get());
2984f763b21SMircea Trofin invalidate();
299caf395eeSMircea Trofin }
300c35ad9eeSMircea Trofin size_t NrSupported = 0;
30171059257SMircea Trofin for (size_t I = 0; I < InputSpecs.size(); ++I) {
30271059257SMircea Trofin auto &InputSpec = InputSpecs[I];
303caf395eeSMircea Trofin InputFeed[I] = {
30471059257SMircea Trofin TF_GraphOperationByName(Graph.get(), (InputSpec.name()).c_str()),
30571059257SMircea Trofin InputSpec.port()};
306c35ad9eeSMircea Trofin if (!InputFeed[I].oper) {
307c35ad9eeSMircea Trofin continue;
308c35ad9eeSMircea Trofin }
309c35ad9eeSMircea Trofin if (NrSupported++ != I) {
310c35ad9eeSMircea Trofin errs()
311c35ad9eeSMircea Trofin << "Unsupported features must be placed at the end of the InputSpecs";
312c35ad9eeSMircea Trofin invalidate();
313c35ad9eeSMircea Trofin return;
314c35ad9eeSMircea Trofin }
31571059257SMircea Trofin if (!checkReportAndInvalidate(InputFeed[I], InputSpec))
316caf395eeSMircea Trofin return;
317e4794ff5SMircea Trofin initInput(I, static_cast<TF_DataType>(getTFTypeIndex(InputSpec.type())),
31871059257SMircea Trofin InputSpec.shape());
319caf395eeSMircea Trofin }
320c35ad9eeSMircea Trofin InputFeed.resize(NrSupported);
321c35ad9eeSMircea Trofin Input.resize(NrSupported);
322c35ad9eeSMircea Trofin
323b51e844fSMircea Trofin for (size_t I = 0; I < OutputSpecsSize; ++I) {
324b51e844fSMircea Trofin auto OutputSpec = GetOutputSpecs(I);
325caf395eeSMircea Trofin OutputFeed[I] = {
32671059257SMircea Trofin TF_GraphOperationByName(Graph.get(), (OutputSpec.name()).c_str()),
32771059257SMircea Trofin OutputSpec.port()};
32871059257SMircea Trofin if (!checkReportAndInvalidate(OutputFeed[I], OutputSpec))
329caf395eeSMircea Trofin return;
330caf395eeSMircea Trofin }
331caf395eeSMircea Trofin }
332caf395eeSMircea Trofin
TFModelEvaluator(StringRef SavedModelPath,const std::vector<TensorSpec> & InputSpecs,function_ref<TensorSpec (size_t)> GetOutputSpecs,size_t OutputSpecsSize,const char * Tags)333b51e844fSMircea Trofin TFModelEvaluator::TFModelEvaluator(
334b51e844fSMircea Trofin StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
335b51e844fSMircea Trofin function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
336b51e844fSMircea Trofin const char *Tags)
337b51e844fSMircea Trofin : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, GetOutputSpecs,
338b51e844fSMircea Trofin OutputSpecsSize, Tags)) {
339b51e844fSMircea Trofin if (!Impl->isValid())
340b51e844fSMircea Trofin Impl.reset();
341b51e844fSMircea Trofin }
342b51e844fSMircea Trofin
TFModelEvaluator(StringRef SavedModelPath,const std::vector<TensorSpec> & InputSpecs,const std::vector<TensorSpec> & OutputSpecs,const char * Tags)3434f763b21SMircea Trofin TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
34471059257SMircea Trofin const std::vector<TensorSpec> &InputSpecs,
34571059257SMircea Trofin const std::vector<TensorSpec> &OutputSpecs,
3464f763b21SMircea Trofin const char *Tags)
347b51e844fSMircea Trofin : TFModelEvaluator(
348b51e844fSMircea Trofin SavedModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I]; },
349b51e844fSMircea Trofin OutputSpecs.size(), Tags) {}
3504f763b21SMircea Trofin
~TFModelEvaluatorImpl()3514f763b21SMircea Trofin TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {
352caf395eeSMircea Trofin for (auto *T : Input) {
353caf395eeSMircea Trofin TF_DeleteTensor(T);
354caf395eeSMircea Trofin }
355caf395eeSMircea Trofin if (Session == nullptr)
356caf395eeSMircea Trofin return;
357caf395eeSMircea Trofin auto Status = createTFStatus();
358caf395eeSMircea Trofin TF_DeleteSession(Session, Status.get());
359caf395eeSMircea Trofin Session = nullptr;
360caf395eeSMircea Trofin if (TF_GetCode(Status.get()) != TF_Code::TF_OK)
361caf395eeSMircea Trofin errs() << "Could not delete TF session";
362caf395eeSMircea Trofin }
363caf395eeSMircea Trofin
checkReportAndInvalidate(const TF_Output & Output,const TensorSpec & OutputSpec)36471059257SMircea Trofin bool TFModelEvaluatorImpl::checkReportAndInvalidate(
36571059257SMircea Trofin const TF_Output &Output, const TensorSpec &OutputSpec) {
3664f763b21SMircea Trofin if (Output.oper)
3674f763b21SMircea Trofin return true;
36871059257SMircea Trofin errs() << "Could not find TF_Output named: " + OutputSpec.name();
3694f763b21SMircea Trofin IsValid = false;
3704f763b21SMircea Trofin return IsValid;
3714f763b21SMircea Trofin }
3724f763b21SMircea Trofin
evaluate()373caf395eeSMircea Trofin Optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
374caf395eeSMircea Trofin if (!isValid())
375caf395eeSMircea Trofin return None;
3764f763b21SMircea Trofin std::unique_ptr<EvaluationResultImpl> Ret =
3774f763b21SMircea Trofin std::make_unique<EvaluationResultImpl>(Impl->OutputSize());
378caf395eeSMircea Trofin auto Status = createTFStatus();
3794f763b21SMircea Trofin Impl->evaluate(Ret->getOutput().data(), Status.get());
380caf395eeSMircea Trofin if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
381caf395eeSMircea Trofin errs() << TF_Message(Status.get());
3824f763b21SMircea Trofin Impl.reset();
383caf395eeSMircea Trofin return None;
384caf395eeSMircea Trofin }
3854f763b21SMircea Trofin return EvaluationResult(std::move(Ret));
386caf395eeSMircea Trofin }
387caf395eeSMircea Trofin
initInput(size_t Index,TF_DataType Type,const std::vector<int64_t> & Dimensions)3884f763b21SMircea Trofin void TFModelEvaluatorImpl::initInput(size_t Index, TF_DataType Type,
389caf395eeSMircea Trofin const std::vector<int64_t> &Dimensions) {
390caf395eeSMircea Trofin int64_t TotalSize = TF_DataTypeSize(Type);
391caf395eeSMircea Trofin for (auto &D : Dimensions)
392caf395eeSMircea Trofin TotalSize *= D;
393caf395eeSMircea Trofin
394caf395eeSMircea Trofin Input[Index] =
395caf395eeSMircea Trofin TF_AllocateTensor(Type, Dimensions.data(), Dimensions.size(), TotalSize);
396caf395eeSMircea Trofin std::memset(TF_TensorData(Input[Index]), 0, TotalSize);
397caf395eeSMircea Trofin }
3984f763b21SMircea Trofin
getUntypedInput(size_t Index)3994f763b21SMircea Trofin void *TFModelEvaluator::getUntypedInput(size_t Index) {
400c35ad9eeSMircea Trofin if (Index < Impl->getInput().size())
4014f763b21SMircea Trofin return TF_TensorData(Impl->getInput()[Index]);
402c35ad9eeSMircea Trofin return nullptr;
4034f763b21SMircea Trofin }
4044f763b21SMircea Trofin
EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl)4054f763b21SMircea Trofin TFModelEvaluator::EvaluationResult::EvaluationResult(
4064f763b21SMircea Trofin std::unique_ptr<EvaluationResultImpl> Impl)
4074f763b21SMircea Trofin : Impl(std::move(Impl)) {}
4084f763b21SMircea Trofin
EvaluationResult(EvaluationResult && Other)4094f763b21SMircea Trofin TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
4104f763b21SMircea Trofin : Impl(std::move(Other.Impl)) {}
4114f763b21SMircea Trofin
412b18c41c6SMircea Trofin TFModelEvaluator::EvaluationResult &
operator =(EvaluationResult && Other)413b18c41c6SMircea Trofin TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
414b18c41c6SMircea Trofin Impl = std::move(Other.Impl);
415b18c41c6SMircea Trofin return *this;
416b18c41c6SMircea Trofin }
417b18c41c6SMircea Trofin
getUntypedTensorValue(size_t Index)4184f763b21SMircea Trofin void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
4194f763b21SMircea Trofin return TF_TensorData(Impl->getOutput()[Index]);
4204f763b21SMircea Trofin }
4214f763b21SMircea Trofin
422b18c41c6SMircea Trofin const void *
getUntypedTensorValue(size_t Index) const423b18c41c6SMircea Trofin TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
424b18c41c6SMircea Trofin return TF_TensorData(Impl->getOutput()[Index]);
425b18c41c6SMircea Trofin }
426b18c41c6SMircea Trofin
~EvaluationResult()4274f763b21SMircea Trofin TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
~TFModelEvaluator()4284f763b21SMircea Trofin TFModelEvaluator::~TFModelEvaluator() {}
42936bb1fb1SMircea Trofin
Logger(const std::vector<LoggedFeatureSpec> & FeatureSpecs,const TensorSpec & RewardSpec,bool IncludeReward)43055e12f70SMircea Trofin Logger::Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
43155e12f70SMircea Trofin const TensorSpec &RewardSpec, bool IncludeReward)
43255e12f70SMircea Trofin : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
43355e12f70SMircea Trofin IncludeReward(IncludeReward),
43455e12f70SMircea Trofin LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
43555e12f70SMircea Trofin IncludeReward)) {}
43655e2d206SMircea Trofin
~Logger()43755e12f70SMircea Trofin Logger::~Logger() {}
43836bb1fb1SMircea Trofin
43955e12f70SMircea Trofin #define LOG_REWARD(NAME, TYPE) \
44055e12f70SMircea Trofin void Logger::log##NAME##Reward(TYPE Value) { \
44155e12f70SMircea Trofin assert(IncludeReward); \
44255e12f70SMircea Trofin LoggerData->logReward(Value); \
44355e2d206SMircea Trofin }
44455e12f70SMircea Trofin
LOG_REWARD(Float,float)44555e12f70SMircea Trofin LOG_REWARD(Float, float)
44655e12f70SMircea Trofin LOG_REWARD(Int32, int32_t)
44755e12f70SMircea Trofin LOG_REWARD(Int64, int64_t)
44855e12f70SMircea Trofin #undef LOG_REWARD
44955e12f70SMircea Trofin
45055e12f70SMircea Trofin #define LOG_FINAL_REWARD(NAME, TYPE) \
45155e12f70SMircea Trofin void Logger::log##NAME##FinalReward(TYPE Value) { \
45255e12f70SMircea Trofin assert(RewardSpec.isElementType<TYPE>()); \
45355e12f70SMircea Trofin for (size_t I = 1; I < LoggerData->getNrRecords(); ++I) \
45455e12f70SMircea Trofin log##NAME##Reward(0); \
45555e12f70SMircea Trofin log##NAME##Reward(Value); \
45636bb1fb1SMircea Trofin }
45755e12f70SMircea Trofin
45855e12f70SMircea Trofin LOG_FINAL_REWARD(Float, float)
45955e12f70SMircea Trofin LOG_FINAL_REWARD(Int32, int32_t)
46055e12f70SMircea Trofin LOG_FINAL_REWARD(Int64, int64_t)
46155e12f70SMircea Trofin #undef LOG_FINAL_REWARD
46255e12f70SMircea Trofin
46355e12f70SMircea Trofin void Logger::logFloatValue(size_t FeatureID, const float *Value) {
46455e12f70SMircea Trofin assert(FeatureSpecs[FeatureID].Spec.isElementType<float>());
46555e12f70SMircea Trofin logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
46655e12f70SMircea Trofin }
46755e12f70SMircea Trofin
logInt64Value(size_t FeatureID,const int64_t * Value)46855e12f70SMircea Trofin void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) {
46955e12f70SMircea Trofin assert(FeatureSpecs[FeatureID].Spec.isElementType<int64_t>());
47055e12f70SMircea Trofin logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
47155e12f70SMircea Trofin }
47255e12f70SMircea Trofin
logInt32Value(size_t FeatureID,const int32_t * Value)47355e12f70SMircea Trofin void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) {
47455e12f70SMircea Trofin assert(FeatureSpecs[FeatureID].Spec.isElementType<int32_t>());
47555e12f70SMircea Trofin logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
47655e12f70SMircea Trofin }
47755e12f70SMircea Trofin
logSpecifiedTensorValue(size_t FeatureID,const char * RawData)47855e12f70SMircea Trofin void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) {
47955e12f70SMircea Trofin const auto &Spec = FeatureSpecs[FeatureID].Spec;
48055e12f70SMircea Trofin char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID);
48155e12f70SMircea Trofin if (Spec.isElementType<int32_t>())
48255e12f70SMircea Trofin for (size_t I = 0; I < Spec.getElementCount(); ++I)
48355e12f70SMircea Trofin (reinterpret_cast<int64_t *>(Buff))[I] =
48455e12f70SMircea Trofin static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]);
48555e12f70SMircea Trofin else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>())
48655e12f70SMircea Trofin std::memcpy(Buff, RawData,
48755e12f70SMircea Trofin Spec.getElementCount() * Spec.getElementByteSize());
48855e12f70SMircea Trofin else
48955e12f70SMircea Trofin llvm_unreachable("Unsupported tensor type");
49055e12f70SMircea Trofin }
49155e12f70SMircea Trofin
addEntryAndGetFloatOrInt64Buffer(size_t FeatureID)49255e12f70SMircea Trofin char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
49355e12f70SMircea Trofin return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
49455e12f70SMircea Trofin }
49555e12f70SMircea Trofin
flush(std::string * Str)4961f5dceb1SMircea Trofin void Logger::flush(std::string *Str) { LoggerData->flush(Str); }
4971f5dceb1SMircea Trofin
flush(raw_ostream & OS)4981f5dceb1SMircea Trofin void Logger::flush(raw_ostream &OS) {
4991f5dceb1SMircea Trofin std::string Buff;
5001f5dceb1SMircea Trofin LoggerData->flush(&Buff);
5011f5dceb1SMircea Trofin OS << Buff;
5021f5dceb1SMircea Trofin }
5031f5dceb1SMircea Trofin
flushLogs(raw_ostream & OS,const StringMap<std::unique_ptr<Logger>> & Loggers)5041f5dceb1SMircea Trofin void Logger::flushLogs(raw_ostream &OS,
5051f5dceb1SMircea Trofin const StringMap<std::unique_ptr<Logger>> &Loggers) {
5061f5dceb1SMircea Trofin google::protobuf::Struct Msg;
5071f5dceb1SMircea Trofin for (const auto &NamedLogger : Loggers) {
5081f5dceb1SMircea Trofin tensorflow::SequenceExample SE;
5091f5dceb1SMircea Trofin const auto &Logger = NamedLogger.second;
5101f5dceb1SMircea Trofin std::string Unencoded;
5111f5dceb1SMircea Trofin if (Logger->LoggerData->getNrRecords() > 0)
5121f5dceb1SMircea Trofin Logger->flush(&Unencoded);
5131f5dceb1SMircea Trofin
5141f5dceb1SMircea Trofin (*Msg.mutable_fields())[NamedLogger.first().str()]
5151f5dceb1SMircea Trofin .mutable_string_value()
5161f5dceb1SMircea Trofin ->append(ProtobufTextMode ? Unencoded : encodeBase64(Unencoded));
5171f5dceb1SMircea Trofin }
5181f5dceb1SMircea Trofin
5191f5dceb1SMircea Trofin std::string OutStr;
5201f5dceb1SMircea Trofin serialize(Msg, &OutStr);
5211f5dceb1SMircea Trofin OS << OutStr;
5221f5dceb1SMircea Trofin }
5234fe912f1SNico Weber #endif // defined(LLVM_HAVE_TF_API)
524