//===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation of an MLModelRunner for 'development' mode, i.e. evaluation
// happens off a model that's provided from the command line and is interpreted.
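//
// Example usage (a minimal, hedged sketch, not taken from an in-tree client;
// the feature name, decision name, model path, and FeatureValue below are
// hypothetical placeholders):
//
//   std::vector<TensorSpec> InputSpecs{
//       TensorSpec::createSpec<int64_t>("some_feature", {1})};
//   if (auto MUTR = ModelUnderTrainingRunner::createAndEnsureValid(
//           Ctx, "/path/to/saved_model", "some_decision", InputSpecs,
//           /*OutputSpecsPathOverride=*/"")) {
//     *MUTR->getTensor<int64_t>(0) = FeatureValue;
//     int64_t Advice = MUTR->evaluate<int64_t>();
//     (void)Advice;
//   }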
//
//===----------------------------------------------------------------------===//

#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TF_API)

#include "llvm/Analysis/ModelUnderTrainingRunner.h"

using namespace llvm;

ModelUnderTrainingRunner::ModelUnderTrainingRunner(
    LLVMContext &Ctx, const std::string &ModelPath,
    const std::vector<TensorSpec> &InputSpecs,
    const std::vector<LoggedFeatureSpec> &OutputSpecs)
    : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
      OutputSpecs(OutputSpecs) {
  Evaluator = std::make_unique<TFModelEvaluator>(
      ModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I].Spec; },
      OutputSpecs.size());
  if (!Evaluator || !Evaluator->isValid()) {
    Ctx.emitError("Failed to create saved model evaluator");
    Evaluator.reset();
    return;
  }

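  // Bind each input spec to the evaluator's pre-allocated input buffer so
  // callers can write feature values directly via MLModelRunner::getTensor.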
  for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
    setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
  }
}

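// Runs the model over the already-populated input buffers. Returns a pointer
// to the first output tensor, or nullptr if evaluation failed; the full
// evaluation result is cached in LastEvaluationResult.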
void *ModelUnderTrainingRunner::evaluateUntyped() {
  LastEvaluationResult = Evaluator->evaluate();
  if (!LastEvaluationResult.hasValue()) {
    Ctx.emitError("Error evaluating model.");
    return nullptr;
  }
  return LastEvaluationResult->getUntypedTensorValue(0);
}

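// Loads the output specs for DecisionName (optionally from
// OutputSpecsPathOverride) and, on success, delegates to the overload below.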
std::unique_ptr<ModelUnderTrainingRunner>
ModelUnderTrainingRunner::createAndEnsureValid(
    LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
    const std::vector<TensorSpec> &InputSpecs,
    StringRef OutputSpecsPathOverride) {
  if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
                                              OutputSpecsPathOverride))
    return createAndEnsureValid(Ctx, ModelPath, DecisionName, InputSpecs,
                                *MaybeOutputSpecs);
  Ctx.emitError("Could not load the policy model from the provided path");
  return nullptr;
}

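// Overload taking already-loaded output specs; constructs the runner and
// verifies that the underlying evaluator initialized successfully.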
std::unique_ptr<ModelUnderTrainingRunner>
ModelUnderTrainingRunner::createAndEnsureValid(
    LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
    const std::vector<TensorSpec> &InputSpecs,
    const std::vector<LoggedFeatureSpec> &OutputSpecs) {
  std::unique_ptr<ModelUnderTrainingRunner> MUTR;
  MUTR.reset(
      new ModelUnderTrainingRunner(Ctx, ModelPath, InputSpecs, OutputSpecs));
  if (MUTR && MUTR->isValid())
    return MUTR;

  Ctx.emitError("Could not load or create model evaluator.");
  return nullptr;
}

#endif // defined(LLVM_HAVE_TF_API)