1caf395eeSMircea Trofin //===- TFUtils.cpp - tensorflow evaluation utilities ----------------------===//
2caf395eeSMircea Trofin //
3c874dd53SChristopher Di Bella // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4c874dd53SChristopher Di Bella // See https://llvm.org/LICENSE.txt for license information.
5c874dd53SChristopher Di Bella // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6caf395eeSMircea Trofin //
7caf395eeSMircea Trofin //===----------------------------------------------------------------------===//
8caf395eeSMircea Trofin //
9caf395eeSMircea Trofin // This file implements utilities for interfacing with tensorflow C APIs.
10caf395eeSMircea Trofin //
11caf395eeSMircea Trofin //===----------------------------------------------------------------------===//
124fe912f1SNico Weber #include "llvm/Config/config.h"
134fe912f1SNico Weber #if defined(LLVM_HAVE_TF_API)
14caf395eeSMircea Trofin 
15caf395eeSMircea Trofin #include "llvm/ADT/Twine.h"
164b1b109cSMircea Trofin #include "llvm/Analysis/Utils/TFUtils.h"
171f5dceb1SMircea Trofin #include "llvm/Support/Base64.h"
1855e2d206SMircea Trofin #include "llvm/Support/CommandLine.h"
19caf395eeSMircea Trofin #include "llvm/Support/Debug.h"
204b1b109cSMircea Trofin #include "llvm/Support/JSON.h"
21b51e844fSMircea Trofin #include "llvm/Support/MemoryBuffer.h"
228ab2353aSMircea Trofin #include "llvm/Support/Path.h"
23caf395eeSMircea Trofin #include "llvm/Support/raw_ostream.h"
24caf395eeSMircea Trofin 
251f5dceb1SMircea Trofin #include "google/protobuf/struct.pb.h"
2655e2d206SMircea Trofin #include "google/protobuf/text_format.h"
274f763b21SMircea Trofin #include "tensorflow/c/c_api.h"
28caf395eeSMircea Trofin #include "tensorflow/c/c_api_experimental.h"
2955e2d206SMircea Trofin #include "tensorflow/core/example/example.pb.h"
30caf395eeSMircea Trofin #include <cassert>
3190b9c49cSMircea Trofin #include <numeric>
32caf395eeSMircea Trofin 
33caf395eeSMircea Trofin using namespace llvm;
34caf395eeSMircea Trofin 
3555e12f70SMircea Trofin using google::protobuf::Message;
3655e12f70SMircea Trofin using google::protobuf::TextFormat;
3755e12f70SMircea Trofin 
3855e2d206SMircea Trofin static cl::opt<bool>
3955e2d206SMircea Trofin     ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden,
4055e2d206SMircea Trofin                      cl::desc("Output textual (human-readable) protobuf."));
4155e2d206SMircea Trofin 
42caf395eeSMircea Trofin namespace {
43caf395eeSMircea Trofin 
444f763b21SMircea Trofin using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
454f763b21SMircea Trofin using TFSessionOptionsPtr =
464f763b21SMircea Trofin     std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>;
474f763b21SMircea Trofin using TFStatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
484f763b21SMircea Trofin 
49caf395eeSMircea Trofin struct TFInitializer {
TFInitializer__anon5e75051e0111::TFInitializer50caf395eeSMircea Trofin   TFInitializer() {
51caf395eeSMircea Trofin     int Argc = 1;
52caf395eeSMircea Trofin     const char *Name = "";
53caf395eeSMircea Trofin     const char **NamePtr = &Name;
54caf395eeSMircea Trofin     TF_InitMain(Name, &Argc, const_cast<char ***>(&NamePtr));
55caf395eeSMircea Trofin   }
56caf395eeSMircea Trofin };
57caf395eeSMircea Trofin 
ensureInitTF()58*ede60037SNicolai Hähnle bool ensureInitTF() {
59*ede60037SNicolai Hähnle   static TFInitializer TFLibInitializer;
60*ede60037SNicolai Hähnle   return true;
61*ede60037SNicolai Hähnle }
62caf395eeSMircea Trofin 
createTFGraph()634f763b21SMircea Trofin TFGraphPtr createTFGraph() {
644f763b21SMircea Trofin   return TFGraphPtr(TF_NewGraph(), &TF_DeleteGraph);
65caf395eeSMircea Trofin }
66caf395eeSMircea Trofin 
createTFStatus()674f763b21SMircea Trofin TFStatusPtr createTFStatus() {
684f763b21SMircea Trofin   return TFStatusPtr(TF_NewStatus(), &TF_DeleteStatus);
69caf395eeSMircea Trofin }
70caf395eeSMircea Trofin 
createTFSessionOptions()714f763b21SMircea Trofin TFSessionOptionsPtr createTFSessionOptions() {
724f763b21SMircea Trofin   return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
73caf395eeSMircea Trofin }
741f5dceb1SMircea Trofin 
serialize(const Message & SE,std::string * OutStr)751f5dceb1SMircea Trofin void serialize(const Message &SE, std::string *OutStr) {
761f5dceb1SMircea Trofin   if (ProtobufTextMode) {
771f5dceb1SMircea Trofin     TextFormat::PrintToString(SE, OutStr);
781f5dceb1SMircea Trofin   } else {
791f5dceb1SMircea Trofin     *OutStr = SE.SerializeAsString();
801f5dceb1SMircea Trofin   }
811f5dceb1SMircea Trofin }
82e4794ff5SMircea Trofin 
getTFTypeIndex(TensorType TType)83e4794ff5SMircea Trofin int getTFTypeIndex(TensorType TType) {
84e4794ff5SMircea Trofin   switch (TType) {
85e4794ff5SMircea Trofin   case TensorType::Double:
86e4794ff5SMircea Trofin     return TF_DOUBLE;
87e4794ff5SMircea Trofin   case TensorType::Float:
88e4794ff5SMircea Trofin     return TF_FLOAT;
89e4794ff5SMircea Trofin   case TensorType::Int8:
90e4794ff5SMircea Trofin     return TF_INT8;
91e4794ff5SMircea Trofin   case TensorType::UInt8:
92e4794ff5SMircea Trofin     return TF_UINT8;
93e4794ff5SMircea Trofin   case TensorType::Int16:
94e4794ff5SMircea Trofin     return TF_INT16;
95e4794ff5SMircea Trofin   case TensorType::UInt16:
96e4794ff5SMircea Trofin     return TF_UINT16;
97e4794ff5SMircea Trofin   case TensorType::Int32:
98e4794ff5SMircea Trofin     return TF_INT32;
99e4794ff5SMircea Trofin   case TensorType::UInt32:
100e4794ff5SMircea Trofin     return TF_UINT32;
101e4794ff5SMircea Trofin   case TensorType::Int64:
102e4794ff5SMircea Trofin     return TF_INT64;
103e4794ff5SMircea Trofin   case TensorType::UInt64:
104e4794ff5SMircea Trofin     return TF_UINT64;
105e4794ff5SMircea Trofin   case TensorType::Invalid:
106e4794ff5SMircea Trofin     llvm_unreachable("Unknown tensor type");
107e4794ff5SMircea Trofin   }
108e4794ff5SMircea Trofin }
109caf395eeSMircea Trofin } // namespace
110caf395eeSMircea Trofin 
1114f763b21SMircea Trofin namespace llvm {
1124f763b21SMircea Trofin class EvaluationResultImpl {
1134f763b21SMircea Trofin public:
EvaluationResultImpl(size_t OutputSize)1144f763b21SMircea Trofin   EvaluationResultImpl(size_t OutputSize)
1154f763b21SMircea Trofin       : OutputSize(OutputSize), Output(OutputSize){};
1164f763b21SMircea Trofin 
~EvaluationResultImpl()1174f763b21SMircea Trofin   ~EvaluationResultImpl() {
1184f763b21SMircea Trofin     for (auto *P : Output)
1194f763b21SMircea Trofin       if (P)
1204f763b21SMircea Trofin         TF_DeleteTensor(P);
1214f763b21SMircea Trofin   }
1224f763b21SMircea Trofin 
1234f763b21SMircea Trofin   EvaluationResultImpl(const EvaluationResultImpl &) = delete;
1244f763b21SMircea Trofin   EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
getOutput()1254f763b21SMircea Trofin   std::vector<TF_Tensor *> &getOutput() { return Output; }
1264f763b21SMircea Trofin 
1274f763b21SMircea Trofin private:
1284f763b21SMircea Trofin   const size_t OutputSize;
1294f763b21SMircea Trofin   std::vector<TF_Tensor *> Output;
1304f763b21SMircea Trofin };
1314f763b21SMircea Trofin 
1324f763b21SMircea Trofin class TFModelEvaluatorImpl {
1334f763b21SMircea Trofin public:
1344f763b21SMircea Trofin   TFModelEvaluatorImpl(StringRef SavedModelPath,
13571059257SMircea Trofin                        const std::vector<TensorSpec> &InputSpecs,
136b51e844fSMircea Trofin                        function_ref<TensorSpec(size_t)> GetOutputSpecs,
137b51e844fSMircea Trofin                        size_t OutputSpecsSize, const char *Tags);
1384f763b21SMircea Trofin 
isValid() const1394f763b21SMircea Trofin   bool isValid() const { return IsValid; }
OutputSize() const1404f763b21SMircea Trofin   size_t OutputSize() const { return OutputFeed.size(); }
1414f763b21SMircea Trofin 
evaluate(TF_Tensor ** Output,TF_Status * Status)1424f763b21SMircea Trofin   void evaluate(TF_Tensor **Output, TF_Status *Status) {
1434f763b21SMircea Trofin     TF_SessionRun(Session, nullptr, InputFeed.data(), Input.data(),
1444f763b21SMircea Trofin                   Input.size(), OutputFeed.data(), Output, OutputFeed.size(),
1454f763b21SMircea Trofin                   nullptr, 0, nullptr, Status);
1464f763b21SMircea Trofin   }
1474f763b21SMircea Trofin 
1484f763b21SMircea Trofin   void initInput(size_t Index, TF_DataType Type,
1494f763b21SMircea Trofin                  const std::vector<int64_t> &Dimensions);
getInput() const1504f763b21SMircea Trofin   const std::vector<TF_Tensor *> &getInput() const { return Input; }
1514f763b21SMircea Trofin 
1524f763b21SMircea Trofin   ~TFModelEvaluatorImpl();
1534f763b21SMircea Trofin 
1544f763b21SMircea Trofin private:
1554f763b21SMircea Trofin   /// The objects necessary for carrying out an evaluation of the SavedModel.
1564f763b21SMircea Trofin   /// They are expensive to set up, and we maintain them accross all the
1574f763b21SMircea Trofin   /// evaluations of the model.
1584f763b21SMircea Trofin   TF_Session *Session = nullptr;
1594f763b21SMircea Trofin   TFGraphPtr Graph;
1604f763b21SMircea Trofin   TFSessionOptionsPtr Options;
1614f763b21SMircea Trofin 
1624f763b21SMircea Trofin   /// The specification of the input nodes.
1634f763b21SMircea Trofin   std::vector<TF_Output> InputFeed;
1644f763b21SMircea Trofin 
1654f763b21SMircea Trofin   /// The input tensors. They must match by index of the corresponding InputFeed
1664f763b21SMircea Trofin   /// value. We set up the tensors once and just mutate theirs scalars before
1674f763b21SMircea Trofin   /// each evaluation. The input tensors keep their value after an evaluation.
1684f763b21SMircea Trofin   std::vector<TF_Tensor *> Input;
1694f763b21SMircea Trofin 
1704f763b21SMircea Trofin   /// The specification of the output nodes. When evaluating, the tensors in the
1714f763b21SMircea Trofin   /// output tensor vector must match by index the corresponding element in the
1724f763b21SMircea Trofin   /// OutputFeed.
1734f763b21SMircea Trofin   std::vector<TF_Output> OutputFeed;
1744f763b21SMircea Trofin 
invalidate()1754f763b21SMircea Trofin   void invalidate() { IsValid = false; }
1764f763b21SMircea Trofin 
1774f763b21SMircea Trofin   bool IsValid = true;
1784f763b21SMircea Trofin 
1794f763b21SMircea Trofin   /// Reusable utility for ensuring we can bind the requested Name to a node in
1804f763b21SMircea Trofin   /// the SavedModel Graph.
18171059257SMircea Trofin   bool checkReportAndInvalidate(const TF_Output &Output,
18271059257SMircea Trofin                                 const TensorSpec &OutputSpec);
1834f763b21SMircea Trofin };
18455e12f70SMircea Trofin 
18555e12f70SMircea Trofin class LoggerDataImpl {
18655e12f70SMircea Trofin   const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
18755e12f70SMircea Trofin   const TensorSpec RewardSpec;
188ae1a2a09SMircea Trofin   const bool IncludeReward;
18955e12f70SMircea Trofin 
190ae1a2a09SMircea Trofin   std::vector<tensorflow::FeatureList> FeatureLists;
191ae1a2a09SMircea Trofin   tensorflow::FeatureList Reward;
192ae1a2a09SMircea Trofin 
isSelfConsistent(const tensorflow::SequenceExample & SE,size_t NrRecords) const193ae1a2a09SMircea Trofin   bool isSelfConsistent(const tensorflow::SequenceExample &SE,
194ae1a2a09SMircea Trofin                         size_t NrRecords) const {
195ae1a2a09SMircea Trofin     bool Ret = true;
196ae1a2a09SMircea Trofin     for (const auto &TSpecs : LoggedFeatureSpecs) {
197ae1a2a09SMircea Trofin       const auto &Name = TSpecs.getLoggingName();
198ae1a2a09SMircea Trofin       const auto &FL = SE.feature_lists().feature_list().at(Name).feature();
199ae1a2a09SMircea Trofin       if (NrRecords != static_cast<size_t>(FL.size())) {
200ae1a2a09SMircea Trofin         dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected "
201ae1a2a09SMircea Trofin                << NrRecords << " got " << FL.size() << "\n";
202ae1a2a09SMircea Trofin         Ret = false;
203ae1a2a09SMircea Trofin       }
204ae1a2a09SMircea Trofin     }
205ae1a2a09SMircea Trofin     if (IncludeReward && static_cast<size_t>(SE.feature_lists()
206ae1a2a09SMircea Trofin                                                  .feature_list()
207ae1a2a09SMircea Trofin                                                  .at(RewardSpec.name())
208ae1a2a09SMircea Trofin                                                  .feature()
209ae1a2a09SMircea Trofin                                                  .size()) != NrRecords) {
210ae1a2a09SMircea Trofin       dbgs() << "[TF-UTILS]: reward is missing records.\n";
211ae1a2a09SMircea Trofin       Ret = false;
212ae1a2a09SMircea Trofin     }
213ae1a2a09SMircea Trofin     return Ret;
214ae1a2a09SMircea Trofin   }
215ae1a2a09SMircea Trofin 
transferLog(tensorflow::SequenceExample & SE)216ae1a2a09SMircea Trofin   void transferLog(tensorflow::SequenceExample &SE) {
217ae1a2a09SMircea Trofin     auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
218ae1a2a09SMircea Trofin     if (IncludeReward)
2198dc3fe0cSMircea Trofin       (*FL)[RewardSpec.name()] = std::move(Reward);
220ae1a2a09SMircea Trofin     assert(FeatureLists.size() == LoggedFeatureSpecs.size());
221ae1a2a09SMircea Trofin     for (size_t I = 0; I < FeatureLists.size(); ++I) {
222ae1a2a09SMircea Trofin       const auto &LFS = LoggedFeatureSpecs[I];
2238dc3fe0cSMircea Trofin       (*FL)[LFS.getLoggingName()] = std::move(FeatureLists[I]);
224ae1a2a09SMircea Trofin     }
225ae1a2a09SMircea Trofin   }
22655e12f70SMircea Trofin 
22755e12f70SMircea Trofin public:
LoggerDataImpl(const std::vector<LoggedFeatureSpec> & LoggedSpecs,const TensorSpec & RewardSpec,bool IncludeReward)22855e12f70SMircea Trofin   LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
22955e12f70SMircea Trofin                  const TensorSpec &RewardSpec, bool IncludeReward)
230ae1a2a09SMircea Trofin       : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
231ae1a2a09SMircea Trofin         IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
23255e12f70SMircea Trofin 
233ae1a2a09SMircea Trofin   // flush the logged info to a stream and clear the log contents.
flush(std::string * Str)2341f5dceb1SMircea Trofin   void flush(std::string *Str) {
235ae1a2a09SMircea Trofin     size_t NrRecords = getNrRecords();
236510402c2SMircea Trofin     (void)NrRecords;
237ae1a2a09SMircea Trofin     tensorflow::SequenceExample SE;
238ae1a2a09SMircea Trofin     transferLog(SE);
239ae1a2a09SMircea Trofin     assert(isSelfConsistent(SE, NrRecords));
2401f5dceb1SMircea Trofin     serialize(SE, Str);
24155e12f70SMircea Trofin   }
24255e12f70SMircea Trofin 
addNewTensor(size_t FeatureID)24355e12f70SMircea Trofin   char *addNewTensor(size_t FeatureID) {
24455e12f70SMircea Trofin     const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
24555e12f70SMircea Trofin     if (Spec.isElementType<float>()) {
24655e12f70SMircea Trofin       auto *RF = FeatureLists[FeatureID]
247ae1a2a09SMircea Trofin                      .add_feature()
24855e12f70SMircea Trofin                      ->mutable_float_list()
24955e12f70SMircea Trofin                      ->mutable_value();
25055e12f70SMircea Trofin       RF->Resize(Spec.getElementCount(), 0.0);
25155e12f70SMircea Trofin       return reinterpret_cast<char *>(RF->mutable_data());
25255e12f70SMircea Trofin     } else if (Spec.isElementType<int32_t>() || Spec.isElementType<int64_t>()) {
25355e12f70SMircea Trofin       auto *RF = FeatureLists[FeatureID]
254ae1a2a09SMircea Trofin                      .add_feature()
25555e12f70SMircea Trofin                      ->mutable_int64_list()
25655e12f70SMircea Trofin                      ->mutable_value();
25755e12f70SMircea Trofin       RF->Resize(Spec.getElementCount(), 0);
25855e12f70SMircea Trofin       return reinterpret_cast<char *>(RF->mutable_data());
25955e12f70SMircea Trofin     }
26055e12f70SMircea Trofin     llvm_unreachable("Unsupported tensor type.");
26155e12f70SMircea Trofin   }
26255e12f70SMircea Trofin 
logReward(T Value)26355e12f70SMircea Trofin   template <typename T> void logReward(T Value) {
264ae1a2a09SMircea Trofin     assert(IncludeReward);
26555e12f70SMircea Trofin     if (RewardSpec.isElementType<float>())
266ae1a2a09SMircea Trofin       Reward.add_feature()->mutable_float_list()->add_value(Value);
26755e12f70SMircea Trofin     else if (RewardSpec.isElementType<int32_t>() ||
26855e12f70SMircea Trofin              RewardSpec.isElementType<int64_t>())
269ae1a2a09SMircea Trofin       Reward.add_feature()->mutable_int64_list()->add_value(Value);
27055e12f70SMircea Trofin     else
27155e12f70SMircea Trofin       llvm_unreachable("Unsupported tensor type.");
27255e12f70SMircea Trofin   }
27355e12f70SMircea Trofin 
getNrRecords() const27455e12f70SMircea Trofin   size_t getNrRecords() const {
275ae1a2a09SMircea Trofin     return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
27655e12f70SMircea Trofin   }
27755e12f70SMircea Trofin };
2784f763b21SMircea Trofin } // namespace llvm
2794f763b21SMircea Trofin 
TFModelEvaluatorImpl(StringRef SavedModelPath,const std::vector<TensorSpec> & InputSpecs,function_ref<TensorSpec (size_t)> GetOutputSpecs,size_t OutputSpecsSize,const char * Tags="serve")2804f763b21SMircea Trofin TFModelEvaluatorImpl::TFModelEvaluatorImpl(
28171059257SMircea Trofin     StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
282b51e844fSMircea Trofin     function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
283b51e844fSMircea Trofin     const char *Tags = "serve")
284caf395eeSMircea Trofin     : Graph(createTFGraph()), Options(createTFSessionOptions()),
28571059257SMircea Trofin       InputFeed(InputSpecs.size()), Input(InputSpecs.size()),
286b51e844fSMircea Trofin       OutputFeed(OutputSpecsSize) {
287caf395eeSMircea Trofin   if (!ensureInitTF()) {
288caf395eeSMircea Trofin     errs() << "Tensorflow should have been initialized";
289caf395eeSMircea Trofin     return;
290caf395eeSMircea Trofin   }
291caf395eeSMircea Trofin   auto Status = createTFStatus();
292caf395eeSMircea Trofin 
293caf395eeSMircea Trofin   Session = TF_LoadSessionFromSavedModel(Options.get(), nullptr,
294caf395eeSMircea Trofin                                          SavedModelPath.str().c_str(), &Tags, 1,
295caf395eeSMircea Trofin                                          Graph.get(), nullptr, Status.get());
296caf395eeSMircea Trofin   if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
297caf395eeSMircea Trofin     errs() << TF_Message(Status.get());
2984f763b21SMircea Trofin     invalidate();
299caf395eeSMircea Trofin   }
300c35ad9eeSMircea Trofin   size_t NrSupported = 0;
30171059257SMircea Trofin   for (size_t I = 0; I < InputSpecs.size(); ++I) {
30271059257SMircea Trofin     auto &InputSpec = InputSpecs[I];
303caf395eeSMircea Trofin     InputFeed[I] = {
30471059257SMircea Trofin         TF_GraphOperationByName(Graph.get(), (InputSpec.name()).c_str()),
30571059257SMircea Trofin         InputSpec.port()};
306c35ad9eeSMircea Trofin     if (!InputFeed[I].oper) {
307c35ad9eeSMircea Trofin       continue;
308c35ad9eeSMircea Trofin     }
309c35ad9eeSMircea Trofin     if (NrSupported++ != I) {
310c35ad9eeSMircea Trofin       errs()
311c35ad9eeSMircea Trofin           << "Unsupported features must be placed at the end of the InputSpecs";
312c35ad9eeSMircea Trofin       invalidate();
313c35ad9eeSMircea Trofin       return;
314c35ad9eeSMircea Trofin     }
31571059257SMircea Trofin     if (!checkReportAndInvalidate(InputFeed[I], InputSpec))
316caf395eeSMircea Trofin       return;
317e4794ff5SMircea Trofin     initInput(I, static_cast<TF_DataType>(getTFTypeIndex(InputSpec.type())),
31871059257SMircea Trofin               InputSpec.shape());
319caf395eeSMircea Trofin   }
320c35ad9eeSMircea Trofin   InputFeed.resize(NrSupported);
321c35ad9eeSMircea Trofin   Input.resize(NrSupported);
322c35ad9eeSMircea Trofin 
323b51e844fSMircea Trofin   for (size_t I = 0; I < OutputSpecsSize; ++I) {
324b51e844fSMircea Trofin     auto OutputSpec = GetOutputSpecs(I);
325caf395eeSMircea Trofin     OutputFeed[I] = {
32671059257SMircea Trofin         TF_GraphOperationByName(Graph.get(), (OutputSpec.name()).c_str()),
32771059257SMircea Trofin         OutputSpec.port()};
32871059257SMircea Trofin     if (!checkReportAndInvalidate(OutputFeed[I], OutputSpec))
329caf395eeSMircea Trofin       return;
330caf395eeSMircea Trofin   }
331caf395eeSMircea Trofin }
332caf395eeSMircea Trofin 
TFModelEvaluator(StringRef SavedModelPath,const std::vector<TensorSpec> & InputSpecs,function_ref<TensorSpec (size_t)> GetOutputSpecs,size_t OutputSpecsSize,const char * Tags)333b51e844fSMircea Trofin TFModelEvaluator::TFModelEvaluator(
334b51e844fSMircea Trofin     StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
335b51e844fSMircea Trofin     function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
336b51e844fSMircea Trofin     const char *Tags)
337b51e844fSMircea Trofin     : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, GetOutputSpecs,
338b51e844fSMircea Trofin                                     OutputSpecsSize, Tags)) {
339b51e844fSMircea Trofin   if (!Impl->isValid())
340b51e844fSMircea Trofin     Impl.reset();
341b51e844fSMircea Trofin }
342b51e844fSMircea Trofin 
TFModelEvaluator(StringRef SavedModelPath,const std::vector<TensorSpec> & InputSpecs,const std::vector<TensorSpec> & OutputSpecs,const char * Tags)3434f763b21SMircea Trofin TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
34471059257SMircea Trofin                                    const std::vector<TensorSpec> &InputSpecs,
34571059257SMircea Trofin                                    const std::vector<TensorSpec> &OutputSpecs,
3464f763b21SMircea Trofin                                    const char *Tags)
347b51e844fSMircea Trofin     : TFModelEvaluator(
348b51e844fSMircea Trofin           SavedModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I]; },
349b51e844fSMircea Trofin           OutputSpecs.size(), Tags) {}
3504f763b21SMircea Trofin 
~TFModelEvaluatorImpl()3514f763b21SMircea Trofin TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {
352caf395eeSMircea Trofin   for (auto *T : Input) {
353caf395eeSMircea Trofin     TF_DeleteTensor(T);
354caf395eeSMircea Trofin   }
355caf395eeSMircea Trofin   if (Session == nullptr)
356caf395eeSMircea Trofin     return;
357caf395eeSMircea Trofin   auto Status = createTFStatus();
358caf395eeSMircea Trofin   TF_DeleteSession(Session, Status.get());
359caf395eeSMircea Trofin   Session = nullptr;
360caf395eeSMircea Trofin   if (TF_GetCode(Status.get()) != TF_Code::TF_OK)
361caf395eeSMircea Trofin     errs() << "Could not delete TF session";
362caf395eeSMircea Trofin }
363caf395eeSMircea Trofin 
checkReportAndInvalidate(const TF_Output & Output,const TensorSpec & OutputSpec)36471059257SMircea Trofin bool TFModelEvaluatorImpl::checkReportAndInvalidate(
36571059257SMircea Trofin     const TF_Output &Output, const TensorSpec &OutputSpec) {
3664f763b21SMircea Trofin   if (Output.oper)
3674f763b21SMircea Trofin     return true;
36871059257SMircea Trofin   errs() << "Could not find TF_Output named: " + OutputSpec.name();
3694f763b21SMircea Trofin   IsValid = false;
3704f763b21SMircea Trofin   return IsValid;
3714f763b21SMircea Trofin }
3724f763b21SMircea Trofin 
evaluate()373caf395eeSMircea Trofin Optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
374caf395eeSMircea Trofin   if (!isValid())
375caf395eeSMircea Trofin     return None;
3764f763b21SMircea Trofin   std::unique_ptr<EvaluationResultImpl> Ret =
3774f763b21SMircea Trofin       std::make_unique<EvaluationResultImpl>(Impl->OutputSize());
378caf395eeSMircea Trofin   auto Status = createTFStatus();
3794f763b21SMircea Trofin   Impl->evaluate(Ret->getOutput().data(), Status.get());
380caf395eeSMircea Trofin   if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
381caf395eeSMircea Trofin     errs() << TF_Message(Status.get());
3824f763b21SMircea Trofin     Impl.reset();
383caf395eeSMircea Trofin     return None;
384caf395eeSMircea Trofin   }
3854f763b21SMircea Trofin   return EvaluationResult(std::move(Ret));
386caf395eeSMircea Trofin }
387caf395eeSMircea Trofin 
initInput(size_t Index,TF_DataType Type,const std::vector<int64_t> & Dimensions)3884f763b21SMircea Trofin void TFModelEvaluatorImpl::initInput(size_t Index, TF_DataType Type,
389caf395eeSMircea Trofin                                      const std::vector<int64_t> &Dimensions) {
390caf395eeSMircea Trofin   int64_t TotalSize = TF_DataTypeSize(Type);
391caf395eeSMircea Trofin   for (auto &D : Dimensions)
392caf395eeSMircea Trofin     TotalSize *= D;
393caf395eeSMircea Trofin 
394caf395eeSMircea Trofin   Input[Index] =
395caf395eeSMircea Trofin       TF_AllocateTensor(Type, Dimensions.data(), Dimensions.size(), TotalSize);
396caf395eeSMircea Trofin   std::memset(TF_TensorData(Input[Index]), 0, TotalSize);
397caf395eeSMircea Trofin }
3984f763b21SMircea Trofin 
getUntypedInput(size_t Index)3994f763b21SMircea Trofin void *TFModelEvaluator::getUntypedInput(size_t Index) {
400c35ad9eeSMircea Trofin   if (Index < Impl->getInput().size())
4014f763b21SMircea Trofin     return TF_TensorData(Impl->getInput()[Index]);
402c35ad9eeSMircea Trofin   return nullptr;
4034f763b21SMircea Trofin }
4044f763b21SMircea Trofin 
EvaluationResult(std::unique_ptr<EvaluationResultImpl> Impl)4054f763b21SMircea Trofin TFModelEvaluator::EvaluationResult::EvaluationResult(
4064f763b21SMircea Trofin     std::unique_ptr<EvaluationResultImpl> Impl)
4074f763b21SMircea Trofin     : Impl(std::move(Impl)) {}
4084f763b21SMircea Trofin 
EvaluationResult(EvaluationResult && Other)4094f763b21SMircea Trofin TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
4104f763b21SMircea Trofin     : Impl(std::move(Other.Impl)) {}
4114f763b21SMircea Trofin 
412b18c41c6SMircea Trofin TFModelEvaluator::EvaluationResult &
operator =(EvaluationResult && Other)413b18c41c6SMircea Trofin TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
414b18c41c6SMircea Trofin   Impl = std::move(Other.Impl);
415b18c41c6SMircea Trofin   return *this;
416b18c41c6SMircea Trofin }
417b18c41c6SMircea Trofin 
getUntypedTensorValue(size_t Index)4184f763b21SMircea Trofin void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
4194f763b21SMircea Trofin   return TF_TensorData(Impl->getOutput()[Index]);
4204f763b21SMircea Trofin }
4214f763b21SMircea Trofin 
422b18c41c6SMircea Trofin const void *
getUntypedTensorValue(size_t Index) const423b18c41c6SMircea Trofin TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
424b18c41c6SMircea Trofin   return TF_TensorData(Impl->getOutput()[Index]);
425b18c41c6SMircea Trofin }
426b18c41c6SMircea Trofin 
~EvaluationResult()4274f763b21SMircea Trofin TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
~TFModelEvaluator()4284f763b21SMircea Trofin TFModelEvaluator::~TFModelEvaluator() {}
42936bb1fb1SMircea Trofin 
Logger(const std::vector<LoggedFeatureSpec> & FeatureSpecs,const TensorSpec & RewardSpec,bool IncludeReward)43055e12f70SMircea Trofin Logger::Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
43155e12f70SMircea Trofin                const TensorSpec &RewardSpec, bool IncludeReward)
43255e12f70SMircea Trofin     : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
43355e12f70SMircea Trofin       IncludeReward(IncludeReward),
43455e12f70SMircea Trofin       LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
43555e12f70SMircea Trofin                                                   IncludeReward)) {}
43655e2d206SMircea Trofin 
~Logger()43755e12f70SMircea Trofin Logger::~Logger() {}
43836bb1fb1SMircea Trofin 
43955e12f70SMircea Trofin #define LOG_REWARD(NAME, TYPE)                                                 \
44055e12f70SMircea Trofin   void Logger::log##NAME##Reward(TYPE Value) {                                 \
44155e12f70SMircea Trofin     assert(IncludeReward);                                                     \
44255e12f70SMircea Trofin     LoggerData->logReward(Value);                                              \
44355e2d206SMircea Trofin   }
44455e12f70SMircea Trofin 
LOG_REWARD(Float,float)44555e12f70SMircea Trofin LOG_REWARD(Float, float)
44655e12f70SMircea Trofin LOG_REWARD(Int32, int32_t)
44755e12f70SMircea Trofin LOG_REWARD(Int64, int64_t)
44855e12f70SMircea Trofin #undef LOG_REWARD
44955e12f70SMircea Trofin 
45055e12f70SMircea Trofin #define LOG_FINAL_REWARD(NAME, TYPE)                                           \
45155e12f70SMircea Trofin   void Logger::log##NAME##FinalReward(TYPE Value) {                            \
45255e12f70SMircea Trofin     assert(RewardSpec.isElementType<TYPE>());                                  \
45355e12f70SMircea Trofin     for (size_t I = 1; I < LoggerData->getNrRecords(); ++I)                    \
45455e12f70SMircea Trofin       log##NAME##Reward(0);                                                    \
45555e12f70SMircea Trofin     log##NAME##Reward(Value);                                                  \
45636bb1fb1SMircea Trofin   }
45755e12f70SMircea Trofin 
45855e12f70SMircea Trofin LOG_FINAL_REWARD(Float, float)
45955e12f70SMircea Trofin LOG_FINAL_REWARD(Int32, int32_t)
46055e12f70SMircea Trofin LOG_FINAL_REWARD(Int64, int64_t)
46155e12f70SMircea Trofin #undef LOG_FINAL_REWARD
46255e12f70SMircea Trofin 
46355e12f70SMircea Trofin void Logger::logFloatValue(size_t FeatureID, const float *Value) {
46455e12f70SMircea Trofin   assert(FeatureSpecs[FeatureID].Spec.isElementType<float>());
46555e12f70SMircea Trofin   logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
46655e12f70SMircea Trofin }
46755e12f70SMircea Trofin 
logInt64Value(size_t FeatureID,const int64_t * Value)46855e12f70SMircea Trofin void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) {
46955e12f70SMircea Trofin   assert(FeatureSpecs[FeatureID].Spec.isElementType<int64_t>());
47055e12f70SMircea Trofin   logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
47155e12f70SMircea Trofin }
47255e12f70SMircea Trofin 
logInt32Value(size_t FeatureID,const int32_t * Value)47355e12f70SMircea Trofin void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) {
47455e12f70SMircea Trofin   assert(FeatureSpecs[FeatureID].Spec.isElementType<int32_t>());
47555e12f70SMircea Trofin   logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
47655e12f70SMircea Trofin }
47755e12f70SMircea Trofin 
logSpecifiedTensorValue(size_t FeatureID,const char * RawData)47855e12f70SMircea Trofin void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) {
47955e12f70SMircea Trofin   const auto &Spec = FeatureSpecs[FeatureID].Spec;
48055e12f70SMircea Trofin   char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID);
48155e12f70SMircea Trofin   if (Spec.isElementType<int32_t>())
48255e12f70SMircea Trofin     for (size_t I = 0; I < Spec.getElementCount(); ++I)
48355e12f70SMircea Trofin       (reinterpret_cast<int64_t *>(Buff))[I] =
48455e12f70SMircea Trofin           static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]);
48555e12f70SMircea Trofin   else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>())
48655e12f70SMircea Trofin     std::memcpy(Buff, RawData,
48755e12f70SMircea Trofin                 Spec.getElementCount() * Spec.getElementByteSize());
48855e12f70SMircea Trofin   else
48955e12f70SMircea Trofin     llvm_unreachable("Unsupported tensor type");
49055e12f70SMircea Trofin }
49155e12f70SMircea Trofin 
addEntryAndGetFloatOrInt64Buffer(size_t FeatureID)49255e12f70SMircea Trofin char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
49355e12f70SMircea Trofin   return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
49455e12f70SMircea Trofin }
49555e12f70SMircea Trofin 
flush(std::string * Str)4961f5dceb1SMircea Trofin void Logger::flush(std::string *Str) { LoggerData->flush(Str); }
4971f5dceb1SMircea Trofin 
flush(raw_ostream & OS)4981f5dceb1SMircea Trofin void Logger::flush(raw_ostream &OS) {
4991f5dceb1SMircea Trofin   std::string Buff;
5001f5dceb1SMircea Trofin   LoggerData->flush(&Buff);
5011f5dceb1SMircea Trofin   OS << Buff;
5021f5dceb1SMircea Trofin }
5031f5dceb1SMircea Trofin 
flushLogs(raw_ostream & OS,const StringMap<std::unique_ptr<Logger>> & Loggers)5041f5dceb1SMircea Trofin void Logger::flushLogs(raw_ostream &OS,
5051f5dceb1SMircea Trofin                        const StringMap<std::unique_ptr<Logger>> &Loggers) {
5061f5dceb1SMircea Trofin   google::protobuf::Struct Msg;
5071f5dceb1SMircea Trofin   for (const auto &NamedLogger : Loggers) {
5081f5dceb1SMircea Trofin     tensorflow::SequenceExample SE;
5091f5dceb1SMircea Trofin     const auto &Logger = NamedLogger.second;
5101f5dceb1SMircea Trofin     std::string Unencoded;
5111f5dceb1SMircea Trofin     if (Logger->LoggerData->getNrRecords() > 0)
5121f5dceb1SMircea Trofin       Logger->flush(&Unencoded);
5131f5dceb1SMircea Trofin 
5141f5dceb1SMircea Trofin     (*Msg.mutable_fields())[NamedLogger.first().str()]
5151f5dceb1SMircea Trofin         .mutable_string_value()
5161f5dceb1SMircea Trofin         ->append(ProtobufTextMode ? Unencoded : encodeBase64(Unencoded));
5171f5dceb1SMircea Trofin   }
5181f5dceb1SMircea Trofin 
5191f5dceb1SMircea Trofin   std::string OutStr;
5201f5dceb1SMircea Trofin   serialize(Msg, &OutStr);
5211f5dceb1SMircea Trofin   OS << OutStr;
5221f5dceb1SMircea Trofin }
5234fe912f1SNico Weber #endif // defined(LLVM_HAVE_TF_API)
524