1 //===- TFUtils.cpp - tensorflow evaluation utilities ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements utilities for interfacing with the TensorFlow C API.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "llvm/Config/config.h"
13 #if defined(LLVM_HAVE_TF_API)
14 
15 #include "llvm/ADT/Twine.h"
16 #include "llvm/Analysis/Utils/TFUtils.h"
17 #include "llvm/Support/Base64.h"
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/JSON.h"
21 #include "llvm/Support/ManagedStatic.h"
22 #include "llvm/Support/MemoryBuffer.h"
23 #include "llvm/Support/Path.h"
24 #include "llvm/Support/raw_ostream.h"
25 
26 #include "google/protobuf/struct.pb.h"
27 #include "google/protobuf/text_format.h"
28 #include "tensorflow/c/c_api.h"
29 #include "tensorflow/c/c_api_experimental.h"
30 #include "tensorflow/core/example/example.pb.h"
#include <cassert>
#include <cstring>
#include <memory>
#include <numeric>
33 
34 using namespace llvm;
35 
36 using google::protobuf::Message;
37 using google::protobuf::TextFormat;
38 
39 static cl::opt<bool>
40     ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden,
41                      cl::desc("Output textual (human-readable) protobuf."));
42 
43 namespace {
44 
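// RAII aliases for the TF C API handle types, each paired with its matching
// TF_Delete* function so ownership is released automatically.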
45 using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
46 using TFSessionOptionsPtr =
47     std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>;
48 using TFStatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
49 
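// One-time initialization of the TF library, performed lazily (via
// ManagedStatic below) the first time ensureInitTF() is called.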
50 struct TFInitializer {
51   TFInitializer() {
    assert(!IsInitialized && "TFInitializer should only be constructed once");
53     int Argc = 1;
54     const char *Name = "";
55     const char **NamePtr = &Name;
56     TF_InitMain(Name, &Argc, const_cast<char ***>(&NamePtr));
57     IsInitialized = true;
58   }
59   bool IsInitialized = false;
60 };
61 
62 llvm::ManagedStatic<TFInitializer> TFLibInitializer;
63 
64 bool ensureInitTF() { return TFLibInitializer->IsInitialized; }
65 
66 TFGraphPtr createTFGraph() {
67   return TFGraphPtr(TF_NewGraph(), &TF_DeleteGraph);
68 }
69 
70 TFStatusPtr createTFStatus() {
71   return TFStatusPtr(TF_NewStatus(), &TF_DeleteStatus);
72 }
73 
74 TFSessionOptionsPtr createTFSessionOptions() {
75   return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
76 }
77 
78 void serialize(const Message &SE, std::string *OutStr) {
79   if (ProtobufTextMode) {
80     TextFormat::PrintToString(SE, OutStr);
81   } else {
82     *OutStr = SE.SerializeAsString();
83   }
84 }
85 } // namespace
86 
87 namespace llvm {
88 class EvaluationResultImpl {
89 public:
  EvaluationResultImpl(size_t OutputSize)
      : OutputSize(OutputSize), Output(OutputSize) {}
92 
93   ~EvaluationResultImpl() {
94     for (auto *P : Output)
95       if (P)
96         TF_DeleteTensor(P);
97   }
98 
99   EvaluationResultImpl(const EvaluationResultImpl &) = delete;
100   EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
101   std::vector<TF_Tensor *> &getOutput() { return Output; }
102 
103 private:
104   const size_t OutputSize;
105   std::vector<TF_Tensor *> Output;
106 };
107 
108 size_t TensorSpec::getElementByteSize() const {
109   return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex));
110 }
111 
112 TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
113                        const std::vector<int64_t> &Shape)
114     : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
      ElementCount(std::accumulate(Shape.begin(), Shape.end(), int64_t{1},
                                   std::multiplies<int64_t>())) {}
117 
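// Parse a TensorSpec from a JSON dict. A minimal sketch of the expected shape
// of the dict (the name and values here are hypothetical):
//   {"name": "feature_name", "port": 0, "type": "int64_t", "shape": [1]}
// where "type" must spell one of the TFUTILS_SUPPORTED_TYPES C++ type names.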
118 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
119                                            const json::Value &Value) {
120   auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
121     std::string S;
122     llvm::raw_string_ostream OS(S);
123     OS << Value;
124     Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
125     return None;
126   };
127   // FIXME: accept a Path as a parameter, and use it for error reporting.
128   json::Path::Root Root("tensor_spec");
129   json::ObjectMapper Mapper(Value, Root);
130   if (!Mapper)
131     return EmitError("Value is not a dict");
132 
133   std::string TensorName;
134   int TensorPort = -1;
135   std::string TensorType;
136   std::vector<int64_t> TensorShape;
137 
138   if (!Mapper.map<std::string>("name", TensorName))
139     return EmitError("'name' property not present or not a string");
140   if (!Mapper.map<std::string>("type", TensorType))
141     return EmitError("'type' property not present or not a string");
142   if (!Mapper.map<int>("port", TensorPort))
143     return EmitError("'port' property not present or not an int");
144   if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
145     return EmitError("'shape' property not present or not an int array");
146 
147 #define PARSE_TYPE(T, E)                                                       \
148   if (TensorType == #T)                                                        \
149     return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
150   TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
151 #undef PARSE_TYPE
152   return None;
153 }
154 
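// Load the output specs from SpecFileOverride or, if that is empty, from
// <ModelPath>/output_spec.json. A minimal sketch of the expected file, with
// hypothetical tensor and logging names:
//   [
//     {
//       "logging_name": "decision",
//       "tensor_spec": {
//         "name": "StatefulPartitionedCall",
//         "port": 0,
//         "type": "int64_t",
//         "shape": [1]
//       }
//     }
//   ]
// The first entry must carry the logging_name ExpectedDecisionName.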
155 Optional<std::vector<LoggedFeatureSpec>>
156 loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
157                 StringRef ModelPath, StringRef SpecFileOverride) {
158   SmallVector<char, 128> OutputSpecsPath;
159   StringRef FileName = SpecFileOverride;
160   if (FileName.empty()) {
161     llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
162     FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
163   }
164 
165   auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
166   if (!BufferOrError) {
    Ctx.emitError("Error opening output specs file: " + FileName + ": " +
                  BufferOrError.getError().message());
169     return None;
170   }
  auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
  if (!ParsedJSONValues) {
    Ctx.emitError("Could not parse specs file: " + FileName + ": " +
                  toString(ParsedJSONValues.takeError()));
    return None;
  }
176   auto ValuesArray = ParsedJSONValues->getAsArray();
177   if (!ValuesArray) {
178     Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
179                   "logging_name:<name>} dictionaries");
180     return None;
181   }
182   std::vector<LoggedFeatureSpec> Ret;
183   for (const auto &Value : *ValuesArray)
184     if (const auto *Obj = Value.getAsObject())
185       if (const auto *SpecPart = Obj->get("tensor_spec"))
186         if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
187           if (auto LoggingName = Obj->getString("logging_name")) {
188             if (!TensorSpec->isElementType<int64_t>() &&
189                 !TensorSpec->isElementType<int32_t>() &&
190                 !TensorSpec->isElementType<float>()) {
191               Ctx.emitError(
192                   "Only int64, int32, and float tensors are supported. "
193                   "Found unsupported type for tensor named " +
194                   TensorSpec->name());
195               return None;
196             }
197             Ret.push_back({*TensorSpec, LoggingName->str()});
198           }
199 
200   if (ValuesArray->size() != Ret.size()) {
201     Ctx.emitError(
202         "Unable to parse output spec. It should be a json file containing an "
203         "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
204         "with a json object describing a TensorSpec; and a 'logging_name' key, "
205         "which is a string to use as name when logging this tensor in the "
206         "training log.");
207     return None;
208   }
209   if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
210     Ctx.emitError("The first output spec must describe the decision tensor, "
211                   "and must have the logging_name " +
212                   StringRef(ExpectedDecisionName));
213     return None;
214   }
215   return Ret;
216 }
217 
218 class TFModelEvaluatorImpl {
219 public:
220   TFModelEvaluatorImpl(StringRef SavedModelPath,
221                        const std::vector<TensorSpec> &InputSpecs,
222                        function_ref<TensorSpec(size_t)> GetOutputSpecs,
223                        size_t OutputSpecsSize, const char *Tags);
224 
225   bool isValid() const { return IsValid; }
226   size_t OutputSize() const { return OutputFeed.size(); }
227 
228   void evaluate(TF_Tensor **Output, TF_Status *Status) {
229     TF_SessionRun(Session, nullptr, InputFeed.data(), Input.data(),
230                   Input.size(), OutputFeed.data(), Output, OutputFeed.size(),
231                   nullptr, 0, nullptr, Status);
232   }
233 
234   void initInput(size_t Index, TF_DataType Type,
235                  const std::vector<int64_t> &Dimensions);
236   const std::vector<TF_Tensor *> &getInput() const { return Input; }
237 
238   ~TFModelEvaluatorImpl();
239 
240 private:
241   /// The objects necessary for carrying out an evaluation of the SavedModel.
  /// They are expensive to set up, and we maintain them across all the
243   /// evaluations of the model.
244   TF_Session *Session = nullptr;
245   TFGraphPtr Graph;
246   TFSessionOptionsPtr Options;
247 
248   /// The specification of the input nodes.
249   std::vector<TF_Output> InputFeed;
250 
  /// The input tensors. They must match, by index, the corresponding InputFeed
  /// value. We set up the tensors once and just mutate their scalars before
  /// each evaluation. The input tensors keep their value after an evaluation.
254   std::vector<TF_Tensor *> Input;
255 
  /// The specification of the output nodes. When evaluating, the tensors in
  /// the output tensor vector must match, by index, the corresponding element
  /// in OutputFeed.
259   std::vector<TF_Output> OutputFeed;
260 
261   void invalidate() { IsValid = false; }
262 
263   bool IsValid = true;
264 
265   /// Reusable utility for ensuring we can bind the requested Name to a node in
266   /// the SavedModel Graph.
267   bool checkReportAndInvalidate(const TF_Output &Output,
268                                 const TensorSpec &OutputSpec);
269 };
270 
271 class LoggerDataImpl {
272   const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
273   const TensorSpec RewardSpec;
274   const bool IncludeReward;
275 
276   std::vector<tensorflow::FeatureList> FeatureLists;
277   tensorflow::FeatureList Reward;
278 
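  // Check that every feature list (and the reward list, when logging rewards)
  // in SE holds exactly NrRecords entries, i.e. no observation was only
  // partially logged.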
279   bool isSelfConsistent(const tensorflow::SequenceExample &SE,
280                         size_t NrRecords) const {
281     bool Ret = true;
282     for (const auto &TSpecs : LoggedFeatureSpecs) {
283       const auto &Name = TSpecs.getLoggingName();
284       const auto &FL = SE.feature_lists().feature_list().at(Name).feature();
285       if (NrRecords != static_cast<size_t>(FL.size())) {
286         dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected "
287                << NrRecords << " got " << FL.size() << "\n";
288         Ret = false;
289       }
290     }
291     if (IncludeReward && static_cast<size_t>(SE.feature_lists()
292                                                  .feature_list()
293                                                  .at(RewardSpec.name())
294                                                  .feature()
295                                                  .size()) != NrRecords) {
296       dbgs() << "[TF-UTILS]: reward is missing records.\n";
297       Ret = false;
298     }
299     return Ret;
300   }
301 
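  // Move the locally accumulated feature lists (and reward) into SE. The
  // moved-from local lists are left empty, which effectively clears the log.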
302   void transferLog(tensorflow::SequenceExample &SE) {
303     auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
304     if (IncludeReward)
305       (*FL)[RewardSpec.name()] = std::move(Reward);
306     assert(FeatureLists.size() == LoggedFeatureSpecs.size());
307     for (size_t I = 0; I < FeatureLists.size(); ++I) {
308       const auto &LFS = LoggedFeatureSpecs[I];
309       (*FL)[LFS.getLoggingName()] = std::move(FeatureLists[I]);
310     }
311   }
312 
313 public:
314   LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
315                  const TensorSpec &RewardSpec, bool IncludeReward)
316       : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
317         IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
318 
  // Flush the logged info to a string and clear the log contents.
320   void flush(std::string *Str) {
321     size_t NrRecords = getNrRecords();
322     (void)NrRecords;
323     tensorflow::SequenceExample SE;
324     transferLog(SE);
325     assert(isSelfConsistent(SE, NrRecords));
326     serialize(SE, Str);
327   }
328 
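  // Start a new observation record for feature FeatureID and return a raw
  // pointer to its zero-initialized buffer. int32 features are stored widened
  // to int64, since tensorflow::Feature only carries an int64_list.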
329   char *addNewTensor(size_t FeatureID) {
330     const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
331     if (Spec.isElementType<float>()) {
332       auto *RF = FeatureLists[FeatureID]
333                      .add_feature()
334                      ->mutable_float_list()
335                      ->mutable_value();
336       RF->Resize(Spec.getElementCount(), 0.0);
337       return reinterpret_cast<char *>(RF->mutable_data());
338     } else if (Spec.isElementType<int32_t>() || Spec.isElementType<int64_t>()) {
339       auto *RF = FeatureLists[FeatureID]
340                      .add_feature()
341                      ->mutable_int64_list()
342                      ->mutable_value();
343       RF->Resize(Spec.getElementCount(), 0);
344       return reinterpret_cast<char *>(RF->mutable_data());
345     }
346     llvm_unreachable("Unsupported tensor type.");
347   }
348 
349   template <typename T> void logReward(T Value) {
350     assert(IncludeReward);
351     if (RewardSpec.isElementType<float>())
352       Reward.add_feature()->mutable_float_list()->add_value(Value);
353     else if (RewardSpec.isElementType<int32_t>() ||
354              RewardSpec.isElementType<int64_t>())
355       Reward.add_feature()->mutable_int64_list()->add_value(Value);
356     else
357       llvm_unreachable("Unsupported tensor type.");
358   }
359 
360   size_t getNrRecords() const {
361     return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
362   }
363 };
364 } // namespace llvm
365 
366 TFModelEvaluatorImpl::TFModelEvaluatorImpl(
367     StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
368     function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
369     const char *Tags = "serve")
370     : Graph(createTFGraph()), Options(createTFSessionOptions()),
371       InputFeed(InputSpecs.size()), Input(InputSpecs.size()),
372       OutputFeed(OutputSpecsSize) {
373   if (!ensureInitTF()) {
    errs() << "TensorFlow should have been initialized.\n";
375     return;
376   }
377   auto Status = createTFStatus();
378 
379   Session = TF_LoadSessionFromSavedModel(Options.get(), nullptr,
380                                          SavedModelPath.str().c_str(), &Tags, 1,
381                                          Graph.get(), nullptr, Status.get());
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
    errs() << TF_Message(Status.get());
    invalidate();
    return;
  }
386   for (size_t I = 0; I < InputSpecs.size(); ++I) {
387     auto &InputSpec = InputSpecs[I];
388     InputFeed[I] = {
389         TF_GraphOperationByName(Graph.get(), (InputSpec.name()).c_str()),
390         InputSpec.port()};
391     if (!checkReportAndInvalidate(InputFeed[I], InputSpec))
392       return;
393     initInput(I, static_cast<TF_DataType>(InputSpec.typeIndex()),
394               InputSpec.shape());
395   }
396   for (size_t I = 0; I < OutputSpecsSize; ++I) {
397     auto OutputSpec = GetOutputSpecs(I);
398     OutputFeed[I] = {
399         TF_GraphOperationByName(Graph.get(), (OutputSpec.name()).c_str()),
400         OutputSpec.port()};
401     if (!checkReportAndInvalidate(OutputFeed[I], OutputSpec))
402       return;
403   }
404 }
405 
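// Typical usage, as a sketch (the model path and tensor names here are
// hypothetical; TFModelEvaluator and EvaluationResult are declared in
// TFUtils.h):
//   std::vector<TensorSpec> Inputs{
//       TensorSpec::createSpec<int64_t>("input", {1})};
//   std::vector<TensorSpec> Outputs{
//       TensorSpec::createSpec<int64_t>("output", {1})};
//   TFModelEvaluator Evaluator("/path/to/saved_model", Inputs, Outputs);
//   if (Evaluator.isValid()) {
//     *Evaluator.getInput<int64_t>(0) = 42;
//     if (auto ER = Evaluator.evaluate())
//       int64_t Result = *ER->getTensorValue<int64_t>(0);
//   }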
406 TFModelEvaluator::TFModelEvaluator(
407     StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
408     function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
409     const char *Tags)
410     : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, GetOutputSpecs,
411                                     OutputSpecsSize, Tags)) {
412   if (!Impl->isValid())
413     Impl.reset();
414 }
415 
416 TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
417                                    const std::vector<TensorSpec> &InputSpecs,
418                                    const std::vector<TensorSpec> &OutputSpecs,
419                                    const char *Tags)
420     : TFModelEvaluator(
421           SavedModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I]; },
422           OutputSpecs.size(), Tags) {}
423 
424 TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {
425   for (auto *T : Input) {
426     TF_DeleteTensor(T);
427   }
428   if (Session == nullptr)
429     return;
430   auto Status = createTFStatus();
431   TF_DeleteSession(Session, Status.get());
432   Session = nullptr;
433   if (TF_GetCode(Status.get()) != TF_Code::TF_OK)
    errs() << "Could not delete TF session\n";
435 }
436 
437 bool TFModelEvaluatorImpl::checkReportAndInvalidate(
438     const TF_Output &Output, const TensorSpec &OutputSpec) {
439   if (Output.oper)
440     return true;
  errs() << "Could not find TF_Output named: " + OutputSpec.name() << "\n";
  IsValid = false;
  return false;
444 }
445 
446 Optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
447   if (!isValid())
448     return None;
449   std::unique_ptr<EvaluationResultImpl> Ret =
450       std::make_unique<EvaluationResultImpl>(Impl->OutputSize());
451   auto Status = createTFStatus();
452   Impl->evaluate(Ret->getOutput().data(), Status.get());
453   if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
454     errs() << TF_Message(Status.get());
455     Impl.reset();
456     return None;
457   }
458   return EvaluationResult(std::move(Ret));
459 }
460 
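// Allocate the input tensor at Index and zero-initialize it; TotalSize is the
// element byte size times the product of the dimensions.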
461 void TFModelEvaluatorImpl::initInput(size_t Index, TF_DataType Type,
462                                      const std::vector<int64_t> &Dimensions) {
463   int64_t TotalSize = TF_DataTypeSize(Type);
464   for (auto &D : Dimensions)
465     TotalSize *= D;
466 
467   Input[Index] =
468       TF_AllocateTensor(Type, Dimensions.data(), Dimensions.size(), TotalSize);
469   std::memset(TF_TensorData(Input[Index]), 0, TotalSize);
470 }
471 
472 void *TFModelEvaluator::getUntypedInput(size_t Index) {
473   return TF_TensorData(Impl->getInput()[Index]);
474 }
475 
476 TFModelEvaluator::EvaluationResult::EvaluationResult(
477     std::unique_ptr<EvaluationResultImpl> Impl)
478     : Impl(std::move(Impl)) {}
479 
480 TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
481     : Impl(std::move(Other.Impl)) {}
482 
483 TFModelEvaluator::EvaluationResult &
484 TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
485   Impl = std::move(Other.Impl);
486   return *this;
487 }
488 
489 void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
490   return TF_TensorData(Impl->getOutput()[Index]);
491 }
492 
493 const void *
494 TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
495   return TF_TensorData(Impl->getOutput()[Index]);
496 }
497 
498 #define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
499   template <> int TensorSpec::getDataType<T>() { return E; }
500 
501 TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)
502 
503 #undef TFUTILS_GETDATATYPE_IMPL
504 
TFModelEvaluator::EvaluationResult::~EvaluationResult() = default;
TFModelEvaluator::~TFModelEvaluator() = default;
507 
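// Typical logging flow, as a sketch (the specs here are hypothetical): log
// one int64 feature value per observation, then a reward, and flush at the
// end.
//   Logger L({{TensorSpec::createSpec<int64_t>("feature", {1}), None}},
//            TensorSpec::createSpec<float>("reward", {1}),
//            /*IncludeReward=*/true);
//   int64_t V = 1;
//   L.logInt64Value(0, &V);
//   L.logFloatReward(0.5);
//   L.flush(outs());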
508 Logger::Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
509                const TensorSpec &RewardSpec, bool IncludeReward)
510     : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
511       IncludeReward(IncludeReward),
512       LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
513                                                   IncludeReward)) {}
514 
Logger::~Logger() = default;
516 
517 #define LOG_REWARD(NAME, TYPE)                                                 \
518   void Logger::log##NAME##Reward(TYPE Value) {                                 \
519     assert(IncludeReward);                                                     \
520     LoggerData->logReward(Value);                                              \
521   }
522 
523 LOG_REWARD(Float, float)
524 LOG_REWARD(Int32, int32_t)
525 LOG_REWARD(Int64, int64_t)
526 #undef LOG_REWARD
527 
528 #define LOG_FINAL_REWARD(NAME, TYPE)                                           \
529   void Logger::log##NAME##FinalReward(TYPE Value) {                            \
530     assert(RewardSpec.isElementType<TYPE>());                                  \
531     for (size_t I = 1; I < LoggerData->getNrRecords(); ++I)                    \
532       log##NAME##Reward(0);                                                    \
533     log##NAME##Reward(Value);                                                  \
534   }
535 
536 LOG_FINAL_REWARD(Float, float)
537 LOG_FINAL_REWARD(Int32, int32_t)
538 LOG_FINAL_REWARD(Int64, int64_t)
539 #undef LOG_FINAL_REWARD
540 
541 void Logger::logFloatValue(size_t FeatureID, const float *Value) {
542   assert(FeatureSpecs[FeatureID].Spec.isElementType<float>());
543   logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
544 }
545 
546 void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) {
547   assert(FeatureSpecs[FeatureID].Spec.isElementType<int64_t>());
548   logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
549 }
550 
551 void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) {
552   assert(FeatureSpecs[FeatureID].Spec.isElementType<int32_t>());
553   logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
554 }
555 
556 void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) {
557   const auto &Spec = FeatureSpecs[FeatureID].Spec;
558   char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID);
559   if (Spec.isElementType<int32_t>())
560     for (size_t I = 0; I < Spec.getElementCount(); ++I)
561       (reinterpret_cast<int64_t *>(Buff))[I] =
562           static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]);
563   else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>())
564     std::memcpy(Buff, RawData,
565                 Spec.getElementCount() * Spec.getElementByteSize());
566   else
567     llvm_unreachable("Unsupported tensor type");
568 }
569 
char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
  return LoggerData->addNewTensor(FeatureID);
}
573 
574 void Logger::flush(std::string *Str) { LoggerData->flush(Str); }
575 
576 void Logger::flush(raw_ostream &OS) {
577   std::string Buff;
578   LoggerData->flush(&Buff);
579   OS << Buff;
580 }
581 
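// Serialize all loggers into a single protobuf Struct keyed by logger name.
// Each value holds that logger's SequenceExample: appended verbatim in text
// mode (-tfutils-text-log), base64-encoded otherwise.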
582 void Logger::flushLogs(raw_ostream &OS,
583                        const StringMap<std::unique_ptr<Logger>> &Loggers) {
584   google::protobuf::Struct Msg;
585   for (const auto &NamedLogger : Loggers) {
586     tensorflow::SequenceExample SE;
587     const auto &Logger = NamedLogger.second;
588     std::string Unencoded;
589     if (Logger->LoggerData->getNrRecords() > 0)
590       Logger->flush(&Unencoded);
591 
592     (*Msg.mutable_fields())[NamedLogger.first().str()]
593         .mutable_string_value()
594         ->append(ProtobufTextMode ? Unencoded : encodeBase64(Unencoded));
595   }
596 
597   std::string OutStr;
598   serialize(Msg, &OutStr);
599   OS << OutStr;
600 }
601 #endif // defined(LLVM_HAVE_TF_API)
602