1 //===- TFUtils.cpp - tensorflow evaluation utilities ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements utilities for interfacing with tensorflow C APIs.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "llvm/Config/config.h"
13 #if defined(LLVM_HAVE_TF_API)
14 
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/Support/Base64.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"

#include "google/protobuf/struct.pb.h"
#include "google/protobuf/text_format.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/core/example/example.pb.h"
#include <cassert>
#include <cstring>
#include <functional>
#include <numeric>
33 
34 using namespace llvm;
35 
36 using google::protobuf::Message;
37 using google::protobuf::TextFormat;
38 
39 static cl::opt<bool>
40     ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden,
41                      cl::desc("Output textual (human-readable) protobuf."));
42 
43 namespace {
44 
// Owning smart pointers for TF C API handles. Each one carries the matching
// TF_Delete* function as its deleter (see the create* helpers below).
using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
using TFSessionOptionsPtr =
    std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>;
using TFStatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
49 
50 struct TFInitializer {
51   TFInitializer() {
52     assert(!IsInitialized && "TFInitialized should be called only once");
53     int Argc = 1;
54     const char *Name = "";
55     const char **NamePtr = &Name;
56     TF_InitMain(Name, &Argc, const_cast<char ***>(&NamePtr));
57     IsInitialized = true;
58   }
59   bool IsInitialized = false;
60 };
61 
62 llvm::ManagedStatic<TFInitializer> TFLibInitializer;
63 
64 bool ensureInitTF() { return TFLibInitializer->IsInitialized; }
65 
66 TFGraphPtr createTFGraph() {
67   return TFGraphPtr(TF_NewGraph(), &TF_DeleteGraph);
68 }
69 
70 TFStatusPtr createTFStatus() {
71   return TFStatusPtr(TF_NewStatus(), &TF_DeleteStatus);
72 }
73 
74 TFSessionOptionsPtr createTFSessionOptions() {
75   return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
76 }
77 
78 void serialize(const Message &SE, std::string *OutStr) {
79   if (ProtobufTextMode) {
80     TextFormat::PrintToString(SE, OutStr);
81   } else {
82     *OutStr = SE.SerializeAsString();
83   }
84 }
85 
86 int getTFTypeIndex(TensorType TType) {
87   switch (TType) {
88   case TensorType::Double:
89     return TF_DOUBLE;
90   case TensorType::Float:
91     return TF_FLOAT;
92   case TensorType::Int8:
93     return TF_INT8;
94   case TensorType::UInt8:
95     return TF_UINT8;
96   case TensorType::Int16:
97     return TF_INT16;
98   case TensorType::UInt16:
99     return TF_UINT16;
100   case TensorType::Int32:
101     return TF_INT32;
102   case TensorType::UInt32:
103     return TF_UINT32;
104   case TensorType::Int64:
105     return TF_INT64;
106   case TensorType::UInt64:
107     return TF_UINT64;
108   case TensorType::Invalid:
109     llvm_unreachable("Unknown tensor type");
110   }
111 }
112 } // namespace
113 
114 namespace llvm {
115 class EvaluationResultImpl {
116 public:
117   EvaluationResultImpl(size_t OutputSize)
118       : OutputSize(OutputSize), Output(OutputSize){};
119 
120   ~EvaluationResultImpl() {
121     for (auto *P : Output)
122       if (P)
123         TF_DeleteTensor(P);
124   }
125 
126   EvaluationResultImpl(const EvaluationResultImpl &) = delete;
127   EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
128   std::vector<TF_Tensor *> &getOutput() { return Output; }
129 
130 private:
131   const size_t OutputSize;
132   std::vector<TF_Tensor *> Output;
133 };
134 
135 TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
136                        size_t ElementSize, const std::vector<int64_t> &Shape)
137     : Name(Name), Port(Port), Type(Type), Shape(Shape),
138       ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
139                                    std::multiplies<int64_t>())),
140       ElementSize(ElementSize) {}
141 
142 Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
143                                            const json::Value &Value) {
144   auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
145     std::string S;
146     llvm::raw_string_ostream OS(S);
147     OS << Value;
148     Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
149     return None;
150   };
151   // FIXME: accept a Path as a parameter, and use it for error reporting.
152   json::Path::Root Root("tensor_spec");
153   json::ObjectMapper Mapper(Value, Root);
154   if (!Mapper)
155     return EmitError("Value is not a dict");
156 
157   std::string TensorName;
158   int TensorPort = -1;
159   std::string TensorType;
160   std::vector<int64_t> TensorShape;
161 
162   if (!Mapper.map<std::string>("name", TensorName))
163     return EmitError("'name' property not present or not a string");
164   if (!Mapper.map<std::string>("type", TensorType))
165     return EmitError("'type' property not present or not a string");
166   if (!Mapper.map<int>("port", TensorPort))
167     return EmitError("'port' property not present or not an int");
168   if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
169     return EmitError("'shape' property not present or not an int array");
170 
171 #define PARSE_TYPE(T, E)                                                       \
172   if (TensorType == #T)                                                        \
173     return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
174   SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
175 #undef PARSE_TYPE
176   return None;
177 }
178 
179 Optional<std::vector<LoggedFeatureSpec>>
180 loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
181                 StringRef ModelPath, StringRef SpecFileOverride) {
182   SmallVector<char, 128> OutputSpecsPath;
183   StringRef FileName = SpecFileOverride;
184   if (FileName.empty()) {
185     llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
186     FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
187   }
188 
189   auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
190   if (!BufferOrError) {
191     Ctx.emitError("Error opening output specs file: " + FileName + " : " +
192                   BufferOrError.getError().message());
193     return None;
194   }
195   auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
196   if (!ParsedJSONValues) {
197     Ctx.emitError("Could not parse specs file: " + FileName);
198     return None;
199   }
200   auto ValuesArray = ParsedJSONValues->getAsArray();
201   if (!ValuesArray) {
202     Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
203                   "logging_name:<name>} dictionaries");
204     return None;
205   }
206   std::vector<LoggedFeatureSpec> Ret;
207   for (const auto &Value : *ValuesArray)
208     if (const auto *Obj = Value.getAsObject())
209       if (const auto *SpecPart = Obj->get("tensor_spec"))
210         if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
211           if (auto LoggingName = Obj->getString("logging_name")) {
212             if (!TensorSpec->isElementType<int64_t>() &&
213                 !TensorSpec->isElementType<int32_t>() &&
214                 !TensorSpec->isElementType<float>()) {
215               Ctx.emitError(
216                   "Only int64, int32, and float tensors are supported. "
217                   "Found unsupported type for tensor named " +
218                   TensorSpec->name());
219               return None;
220             }
221             Ret.push_back({*TensorSpec, LoggingName->str()});
222           }
223 
224   if (ValuesArray->size() != Ret.size()) {
225     Ctx.emitError(
226         "Unable to parse output spec. It should be a json file containing an "
227         "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
228         "with a json object describing a TensorSpec; and a 'logging_name' key, "
229         "which is a string to use as name when logging this tensor in the "
230         "training log.");
231     return None;
232   }
233   if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
234     Ctx.emitError("The first output spec must describe the decision tensor, "
235                   "and must have the logging_name " +
236                   StringRef(ExpectedDecisionName));
237     return None;
238   }
239   return Ret;
240 }
241 
242 class TFModelEvaluatorImpl {
243 public:
244   TFModelEvaluatorImpl(StringRef SavedModelPath,
245                        const std::vector<TensorSpec> &InputSpecs,
246                        function_ref<TensorSpec(size_t)> GetOutputSpecs,
247                        size_t OutputSpecsSize, const char *Tags);
248 
249   bool isValid() const { return IsValid; }
250   size_t OutputSize() const { return OutputFeed.size(); }
251 
252   void evaluate(TF_Tensor **Output, TF_Status *Status) {
253     TF_SessionRun(Session, nullptr, InputFeed.data(), Input.data(),
254                   Input.size(), OutputFeed.data(), Output, OutputFeed.size(),
255                   nullptr, 0, nullptr, Status);
256   }
257 
258   void initInput(size_t Index, TF_DataType Type,
259                  const std::vector<int64_t> &Dimensions);
260   const std::vector<TF_Tensor *> &getInput() const { return Input; }
261 
262   ~TFModelEvaluatorImpl();
263 
264 private:
265   /// The objects necessary for carrying out an evaluation of the SavedModel.
266   /// They are expensive to set up, and we maintain them accross all the
267   /// evaluations of the model.
268   TF_Session *Session = nullptr;
269   TFGraphPtr Graph;
270   TFSessionOptionsPtr Options;
271 
272   /// The specification of the input nodes.
273   std::vector<TF_Output> InputFeed;
274 
275   /// The input tensors. They must match by index of the corresponding InputFeed
276   /// value. We set up the tensors once and just mutate theirs scalars before
277   /// each evaluation. The input tensors keep their value after an evaluation.
278   std::vector<TF_Tensor *> Input;
279 
280   /// The specification of the output nodes. When evaluating, the tensors in the
281   /// output tensor vector must match by index the corresponding element in the
282   /// OutputFeed.
283   std::vector<TF_Output> OutputFeed;
284 
285   void invalidate() { IsValid = false; }
286 
287   bool IsValid = true;
288 
289   /// Reusable utility for ensuring we can bind the requested Name to a node in
290   /// the SavedModel Graph.
291   bool checkReportAndInvalidate(const TF_Output &Output,
292                                 const TensorSpec &OutputSpec);
293 };
294 
295 class LoggerDataImpl {
296   const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
297   const TensorSpec RewardSpec;
298   const bool IncludeReward;
299 
300   std::vector<tensorflow::FeatureList> FeatureLists;
301   tensorflow::FeatureList Reward;
302 
303   bool isSelfConsistent(const tensorflow::SequenceExample &SE,
304                         size_t NrRecords) const {
305     bool Ret = true;
306     for (const auto &TSpecs : LoggedFeatureSpecs) {
307       const auto &Name = TSpecs.getLoggingName();
308       const auto &FL = SE.feature_lists().feature_list().at(Name).feature();
309       if (NrRecords != static_cast<size_t>(FL.size())) {
310         dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected "
311                << NrRecords << " got " << FL.size() << "\n";
312         Ret = false;
313       }
314     }
315     if (IncludeReward && static_cast<size_t>(SE.feature_lists()
316                                                  .feature_list()
317                                                  .at(RewardSpec.name())
318                                                  .feature()
319                                                  .size()) != NrRecords) {
320       dbgs() << "[TF-UTILS]: reward is missing records.\n";
321       Ret = false;
322     }
323     return Ret;
324   }
325 
326   void transferLog(tensorflow::SequenceExample &SE) {
327     auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
328     if (IncludeReward)
329       (*FL)[RewardSpec.name()] = std::move(Reward);
330     assert(FeatureLists.size() == LoggedFeatureSpecs.size());
331     for (size_t I = 0; I < FeatureLists.size(); ++I) {
332       const auto &LFS = LoggedFeatureSpecs[I];
333       (*FL)[LFS.getLoggingName()] = std::move(FeatureLists[I]);
334     }
335   }
336 
337 public:
338   LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
339                  const TensorSpec &RewardSpec, bool IncludeReward)
340       : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
341         IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
342 
343   // flush the logged info to a stream and clear the log contents.
344   void flush(std::string *Str) {
345     size_t NrRecords = getNrRecords();
346     (void)NrRecords;
347     tensorflow::SequenceExample SE;
348     transferLog(SE);
349     assert(isSelfConsistent(SE, NrRecords));
350     serialize(SE, Str);
351   }
352 
353   char *addNewTensor(size_t FeatureID) {
354     const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
355     if (Spec.isElementType<float>()) {
356       auto *RF = FeatureLists[FeatureID]
357                      .add_feature()
358                      ->mutable_float_list()
359                      ->mutable_value();
360       RF->Resize(Spec.getElementCount(), 0.0);
361       return reinterpret_cast<char *>(RF->mutable_data());
362     } else if (Spec.isElementType<int32_t>() || Spec.isElementType<int64_t>()) {
363       auto *RF = FeatureLists[FeatureID]
364                      .add_feature()
365                      ->mutable_int64_list()
366                      ->mutable_value();
367       RF->Resize(Spec.getElementCount(), 0);
368       return reinterpret_cast<char *>(RF->mutable_data());
369     }
370     llvm_unreachable("Unsupported tensor type.");
371   }
372 
373   template <typename T> void logReward(T Value) {
374     assert(IncludeReward);
375     if (RewardSpec.isElementType<float>())
376       Reward.add_feature()->mutable_float_list()->add_value(Value);
377     else if (RewardSpec.isElementType<int32_t>() ||
378              RewardSpec.isElementType<int64_t>())
379       Reward.add_feature()->mutable_int64_list()->add_value(Value);
380     else
381       llvm_unreachable("Unsupported tensor type.");
382   }
383 
384   size_t getNrRecords() const {
385     return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
386   }
387 };
388 } // namespace llvm
389 
390 TFModelEvaluatorImpl::TFModelEvaluatorImpl(
391     StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
392     function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
393     const char *Tags = "serve")
394     : Graph(createTFGraph()), Options(createTFSessionOptions()),
395       InputFeed(InputSpecs.size()), Input(InputSpecs.size()),
396       OutputFeed(OutputSpecsSize) {
397   if (!ensureInitTF()) {
398     errs() << "Tensorflow should have been initialized";
399     return;
400   }
401   auto Status = createTFStatus();
402 
403   Session = TF_LoadSessionFromSavedModel(Options.get(), nullptr,
404                                          SavedModelPath.str().c_str(), &Tags, 1,
405                                          Graph.get(), nullptr, Status.get());
406   if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
407     errs() << TF_Message(Status.get());
408     invalidate();
409   }
410   for (size_t I = 0; I < InputSpecs.size(); ++I) {
411     auto &InputSpec = InputSpecs[I];
412     InputFeed[I] = {
413         TF_GraphOperationByName(Graph.get(), (InputSpec.name()).c_str()),
414         InputSpec.port()};
415     if (!checkReportAndInvalidate(InputFeed[I], InputSpec))
416       return;
417     initInput(I, static_cast<TF_DataType>(getTFTypeIndex(InputSpec.type())),
418               InputSpec.shape());
419   }
420   for (size_t I = 0; I < OutputSpecsSize; ++I) {
421     auto OutputSpec = GetOutputSpecs(I);
422     OutputFeed[I] = {
423         TF_GraphOperationByName(Graph.get(), (OutputSpec.name()).c_str()),
424         OutputSpec.port()};
425     if (!checkReportAndInvalidate(OutputFeed[I], OutputSpec))
426       return;
427   }
428 }
429 
430 TFModelEvaluator::TFModelEvaluator(
431     StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
432     function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
433     const char *Tags)
434     : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, GetOutputSpecs,
435                                     OutputSpecsSize, Tags)) {
436   if (!Impl->isValid())
437     Impl.reset();
438 }
439 
440 TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
441                                    const std::vector<TensorSpec> &InputSpecs,
442                                    const std::vector<TensorSpec> &OutputSpecs,
443                                    const char *Tags)
444     : TFModelEvaluator(
445           SavedModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I]; },
446           OutputSpecs.size(), Tags) {}
447 
448 TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {
449   for (auto *T : Input) {
450     TF_DeleteTensor(T);
451   }
452   if (Session == nullptr)
453     return;
454   auto Status = createTFStatus();
455   TF_DeleteSession(Session, Status.get());
456   Session = nullptr;
457   if (TF_GetCode(Status.get()) != TF_Code::TF_OK)
458     errs() << "Could not delete TF session";
459 }
460 
461 bool TFModelEvaluatorImpl::checkReportAndInvalidate(
462     const TF_Output &Output, const TensorSpec &OutputSpec) {
463   if (Output.oper)
464     return true;
465   errs() << "Could not find TF_Output named: " + OutputSpec.name();
466   IsValid = false;
467   return IsValid;
468 }
469 
470 Optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
471   if (!isValid())
472     return None;
473   std::unique_ptr<EvaluationResultImpl> Ret =
474       std::make_unique<EvaluationResultImpl>(Impl->OutputSize());
475   auto Status = createTFStatus();
476   Impl->evaluate(Ret->getOutput().data(), Status.get());
477   if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
478     errs() << TF_Message(Status.get());
479     Impl.reset();
480     return None;
481   }
482   return EvaluationResult(std::move(Ret));
483 }
484 
485 void TFModelEvaluatorImpl::initInput(size_t Index, TF_DataType Type,
486                                      const std::vector<int64_t> &Dimensions) {
487   int64_t TotalSize = TF_DataTypeSize(Type);
488   for (auto &D : Dimensions)
489     TotalSize *= D;
490 
491   Input[Index] =
492       TF_AllocateTensor(Type, Dimensions.data(), Dimensions.size(), TotalSize);
493   std::memset(TF_TensorData(Input[Index]), 0, TotalSize);
494 }
495 
496 void *TFModelEvaluator::getUntypedInput(size_t Index) {
497   return TF_TensorData(Impl->getInput()[Index]);
498 }
499 
500 TFModelEvaluator::EvaluationResult::EvaluationResult(
501     std::unique_ptr<EvaluationResultImpl> Impl)
502     : Impl(std::move(Impl)) {}
503 
504 TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
505     : Impl(std::move(Other.Impl)) {}
506 
507 TFModelEvaluator::EvaluationResult &
508 TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
509   Impl = std::move(Other.Impl);
510   return *this;
511 }
512 
513 void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
514   return TF_TensorData(Impl->getOutput()[Index]);
515 }
516 
517 const void *
518 TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
519   return TF_TensorData(Impl->getOutput()[Index]);
520 }
521 
522 #define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
523   template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }
524 
525 SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)
526 
527 #undef TFUTILS_GETDATATYPE_IMPL
528 
529 TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
530 TFModelEvaluator::~TFModelEvaluator() {}
531 
532 Logger::Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
533                const TensorSpec &RewardSpec, bool IncludeReward)
534     : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
535       IncludeReward(IncludeReward),
536       LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
537                                                   IncludeReward)) {}
538 
539 Logger::~Logger() {}
540 
541 #define LOG_REWARD(NAME, TYPE)                                                 \
542   void Logger::log##NAME##Reward(TYPE Value) {                                 \
543     assert(IncludeReward);                                                     \
544     LoggerData->logReward(Value);                                              \
545   }
546 
547 LOG_REWARD(Float, float)
548 LOG_REWARD(Int32, int32_t)
549 LOG_REWARD(Int64, int64_t)
550 #undef LOG_REWARD
551 
552 #define LOG_FINAL_REWARD(NAME, TYPE)                                           \
553   void Logger::log##NAME##FinalReward(TYPE Value) {                            \
554     assert(RewardSpec.isElementType<TYPE>());                                  \
555     for (size_t I = 1; I < LoggerData->getNrRecords(); ++I)                    \
556       log##NAME##Reward(0);                                                    \
557     log##NAME##Reward(Value);                                                  \
558   }
559 
560 LOG_FINAL_REWARD(Float, float)
561 LOG_FINAL_REWARD(Int32, int32_t)
562 LOG_FINAL_REWARD(Int64, int64_t)
563 #undef LOG_FINAL_REWARD
564 
565 void Logger::logFloatValue(size_t FeatureID, const float *Value) {
566   assert(FeatureSpecs[FeatureID].Spec.isElementType<float>());
567   logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
568 }
569 
570 void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) {
571   assert(FeatureSpecs[FeatureID].Spec.isElementType<int64_t>());
572   logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
573 }
574 
575 void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) {
576   assert(FeatureSpecs[FeatureID].Spec.isElementType<int32_t>());
577   logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
578 }
579 
580 void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) {
581   const auto &Spec = FeatureSpecs[FeatureID].Spec;
582   char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID);
583   if (Spec.isElementType<int32_t>())
584     for (size_t I = 0; I < Spec.getElementCount(); ++I)
585       (reinterpret_cast<int64_t *>(Buff))[I] =
586           static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]);
587   else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>())
588     std::memcpy(Buff, RawData,
589                 Spec.getElementCount() * Spec.getElementByteSize());
590   else
591     llvm_unreachable("Unsupported tensor type");
592 }
593 
594 char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
595   return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
596 }
597 
598 void Logger::flush(std::string *Str) { LoggerData->flush(Str); }
599 
600 void Logger::flush(raw_ostream &OS) {
601   std::string Buff;
602   LoggerData->flush(&Buff);
603   OS << Buff;
604 }
605 
606 void Logger::flushLogs(raw_ostream &OS,
607                        const StringMap<std::unique_ptr<Logger>> &Loggers) {
608   google::protobuf::Struct Msg;
609   for (const auto &NamedLogger : Loggers) {
610     tensorflow::SequenceExample SE;
611     const auto &Logger = NamedLogger.second;
612     std::string Unencoded;
613     if (Logger->LoggerData->getNrRecords() > 0)
614       Logger->flush(&Unencoded);
615 
616     (*Msg.mutable_fields())[NamedLogger.first().str()]
617         .mutable_string_value()
618         ->append(ProtobufTextMode ? Unencoded : encodeBase64(Unencoded));
619   }
620 
621   std::string OutStr;
622   serialize(Msg, &OutStr);
623   OS << OutStr;
624 }
625 #endif // defined(LLVM_HAVE_TF_API)
626