//===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of an MLModelRunner for 'development' mode, i.e. evaluation
// happens off a model that's provided from the command line and is interpreted.
//
//===----------------------------------------------------------------------===//
1304f2712eSMircea Trofin 
1404f2712eSMircea Trofin #include "llvm/Config/config.h"
1504f2712eSMircea Trofin #if defined(LLVM_HAVE_TF_API)
1604f2712eSMircea Trofin 
1704f2712eSMircea Trofin #include "llvm/Analysis/ModelUnderTrainingRunner.h"
1804f2712eSMircea Trofin 
1904f2712eSMircea Trofin using namespace llvm;
2004f2712eSMircea Trofin 
ModelUnderTrainingRunner(LLVMContext & Ctx,const std::string & ModelPath,const std::vector<TensorSpec> & InputSpecs,const std::vector<LoggedFeatureSpec> & OutputSpecs)2104f2712eSMircea Trofin ModelUnderTrainingRunner::ModelUnderTrainingRunner(
2204f2712eSMircea Trofin     LLVMContext &Ctx, const std::string &ModelPath,
2304f2712eSMircea Trofin     const std::vector<TensorSpec> &InputSpecs,
2404f2712eSMircea Trofin     const std::vector<LoggedFeatureSpec> &OutputSpecs)
25*c35ad9eeSMircea Trofin     : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
26a120fdd3SMircea Trofin       OutputSpecs(OutputSpecs) {
2704f2712eSMircea Trofin   Evaluator = std::make_unique<TFModelEvaluator>(
2804f2712eSMircea Trofin       ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
2904f2712eSMircea Trofin       OutputSpecs.size());
3004f2712eSMircea Trofin   if (!Evaluator || !Evaluator->isValid()) {
31a81b0c97SMircea Trofin     Ctx.emitError("Failed to create saved model evaluator");
3204f2712eSMircea Trofin     Evaluator.reset();
3304f2712eSMircea Trofin     return;
3404f2712eSMircea Trofin   }
35*c35ad9eeSMircea Trofin 
36*c35ad9eeSMircea Trofin   for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
37*c35ad9eeSMircea Trofin     setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
38*c35ad9eeSMircea Trofin   }
3904f2712eSMircea Trofin }
4004f2712eSMircea Trofin 
evaluateUntyped()4104f2712eSMircea Trofin void *ModelUnderTrainingRunner::evaluateUntyped() {
4204f2712eSMircea Trofin   LastEvaluationResult = Evaluator->evaluate();
4304f2712eSMircea Trofin   if (!LastEvaluationResult.hasValue()) {
4404f2712eSMircea Trofin     Ctx.emitError("Error evaluating model.");
4504f2712eSMircea Trofin     return nullptr;
4604f2712eSMircea Trofin   }
4704f2712eSMircea Trofin   return LastEvaluationResult->getUntypedTensorValue(0);
4804f2712eSMircea Trofin }
4904f2712eSMircea Trofin 
50*c35ad9eeSMircea Trofin std::unique_ptr<ModelUnderTrainingRunner>
createAndEnsureValid(LLVMContext & Ctx,const std::string & ModelPath,StringRef DecisionName,const std::vector<TensorSpec> & InputSpecs,StringRef OutputSpecsPathOverride)51*c35ad9eeSMircea Trofin ModelUnderTrainingRunner::createAndEnsureValid(
52*c35ad9eeSMircea Trofin     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
53*c35ad9eeSMircea Trofin     const std::vector<TensorSpec> &InputSpecs,
54*c35ad9eeSMircea Trofin     StringRef OutputSpecsPathOverride) {
55*c35ad9eeSMircea Trofin   if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
56*c35ad9eeSMircea Trofin                                               OutputSpecsPathOverride))
57*c35ad9eeSMircea Trofin     return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs,
58*c35ad9eeSMircea Trofin                                 *MaybeOutputSpecs);
59*c35ad9eeSMircea Trofin   Ctx.emitError("Could not load the policy model from the provided path");
60*c35ad9eeSMircea Trofin   return nullptr;
6104f2712eSMircea Trofin }
6204f2712eSMircea Trofin 
63a120fdd3SMircea Trofin std::unique_ptr<ModelUnderTrainingRunner>
createAndEnsureValid(LLVMContext & Ctx,const std::string & ModelPath,StringRef DecisionName,const std::vector<TensorSpec> & InputSpecs,const std::vector<LoggedFeatureSpec> & OutputSpecs)64a120fdd3SMircea Trofin ModelUnderTrainingRunner::createAndEnsureValid(
65a120fdd3SMircea Trofin     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
66a120fdd3SMircea Trofin     const std::vector<TensorSpec> &InputSpecs,
67*c35ad9eeSMircea Trofin     const std::vector<LoggedFeatureSpec> &OutputSpecs) {
68a120fdd3SMircea Trofin   std::unique_ptr<ModelUnderTrainingRunner> MUTR;
69*c35ad9eeSMircea Trofin   MUTR.reset(
70*c35ad9eeSMircea Trofin       new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs));
71a120fdd3SMircea Trofin   if (MUTR && MUTR->isValid())
72a120fdd3SMircea Trofin     return MUTR;
73a120fdd3SMircea Trofin 
74*c35ad9eeSMircea Trofin   Ctx.emitError("Could not load or create model evaluator.");
75a120fdd3SMircea Trofin   return nullptr;
76a120fdd3SMircea Trofin }
77a120fdd3SMircea Trofin 
7804f2712eSMircea Trofin #endif // defined(LLVM_HAVE_TF_API)
79