mirror of
https://github.com/Qortal/Brooklyn.git
synced 2025-02-01 07:42:18 +00:00
199 lines
7.1 KiB
C++
199 lines
7.1 KiB
C++
//
|
|
// Copyright © 2017 Arm Ltd. All rights reserved.
|
|
// SPDX-License-Identifier: MIT
|
|
//
|
|
#pragma once
|
|
|
|
#include "InferenceTest.hpp"
|
|
#include "DeepSpeechV1Database.hpp"
|
|
|
|
#include <armnn/utility/Assert.hpp>
|
|
#include <armnn/utility/IgnoreUnused.hpp>
|
|
#include <armnnUtils/FloatingPointComparison.hpp>
|
|
|
|
#include <vector>
|
|
|
|
namespace
|
|
{
|
|
|
|
template<typename Model>
|
|
class DeepSpeechV1TestCase : public InferenceModelTestCase<Model>
|
|
{
|
|
public:
|
|
DeepSpeechV1TestCase(Model& model,
|
|
unsigned int testCaseId,
|
|
const DeepSpeechV1TestCaseData& testCaseData)
|
|
: InferenceModelTestCase<Model>(model,
|
|
testCaseId,
|
|
{ testCaseData.m_InputData.m_InputSeq,
|
|
testCaseData.m_InputData.m_StateH,
|
|
testCaseData.m_InputData.m_StateC},
|
|
{ k_OutputSize1, k_OutputSize2, k_OutputSize3 })
|
|
, m_ExpectedOutputs({testCaseData.m_ExpectedOutputData.m_InputSeq, testCaseData.m_ExpectedOutputData.m_StateH,
|
|
testCaseData.m_ExpectedOutputData.m_StateC})
|
|
{}
|
|
|
|
TestCaseResult ProcessResult(const InferenceTestOptions& options) override
|
|
{
|
|
armnn::IgnoreUnused(options);
|
|
const std::vector<float>& output1 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[0]); // logits
|
|
ARMNN_ASSERT(output1.size() == k_OutputSize1);
|
|
|
|
const std::vector<float>& output2 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[1]); // new_state_c
|
|
ARMNN_ASSERT(output2.size() == k_OutputSize2);
|
|
|
|
const std::vector<float>& output3 = mapbox::util::get<std::vector<float>>(this->GetOutputs()[2]); // new_state_h
|
|
ARMNN_ASSERT(output3.size() == k_OutputSize3);
|
|
|
|
// Check each output to see whether it is the expected value
|
|
for (unsigned int j = 0u; j < output1.size(); j++)
|
|
{
|
|
if(!armnnUtils::within_percentage_tolerance(output1[j], m_ExpectedOutputs.m_InputSeq[j]))
|
|
{
|
|
ARMNN_LOG(error) << "InputSeq for Lstm " << this->GetTestCaseId() <<
|
|
" is incorrect at" << j;
|
|
return TestCaseResult::Failed;
|
|
}
|
|
}
|
|
|
|
for (unsigned int j = 0u; j < output2.size(); j++)
|
|
{
|
|
if(!armnnUtils::within_percentage_tolerance(output2[j], m_ExpectedOutputs.m_StateH[j]))
|
|
{
|
|
ARMNN_LOG(error) << "StateH for Lstm " << this->GetTestCaseId() <<
|
|
" is incorrect";
|
|
return TestCaseResult::Failed;
|
|
}
|
|
}
|
|
|
|
for (unsigned int j = 0u; j < output3.size(); j++)
|
|
{
|
|
if(!armnnUtils::within_percentage_tolerance(output3[j], m_ExpectedOutputs.m_StateC[j]))
|
|
{
|
|
ARMNN_LOG(error) << "StateC for Lstm " << this->GetTestCaseId() <<
|
|
" is incorrect";
|
|
return TestCaseResult::Failed;
|
|
}
|
|
}
|
|
return TestCaseResult::Ok;
|
|
}
|
|
|
|
private:
|
|
|
|
static constexpr unsigned int k_OutputSize1 = 464u;
|
|
static constexpr unsigned int k_OutputSize2 = 2048u;
|
|
static constexpr unsigned int k_OutputSize3 = 2048u;
|
|
|
|
LstmInput m_ExpectedOutputs;
|
|
};
|
|
|
|
template <typename Model>
|
|
class DeepSpeechV1TestCaseProvider : public IInferenceTestCaseProvider
|
|
{
|
|
public:
|
|
template <typename TConstructModelCallable>
|
|
explicit DeepSpeechV1TestCaseProvider(TConstructModelCallable constructModel)
|
|
: m_ConstructModel(constructModel)
|
|
{}
|
|
|
|
virtual void AddCommandLineOptions(cxxopts::Options& options, std::vector<std::string>& required) override
|
|
{
|
|
options
|
|
.allow_unrecognised_options()
|
|
.add_options()
|
|
("s,input-seq-dir", "Path to directory containing test data for m_InputSeq",
|
|
cxxopts::value<std::string>(m_InputSeqDir))
|
|
("h,prev-state-h-dir", "Path to directory containing test data for m_PrevStateH",
|
|
cxxopts::value<std::string>(m_PrevStateHDir))
|
|
("c,prev-state-c-dir", "Path to directory containing test data for m_PrevStateC",
|
|
cxxopts::value<std::string>(m_PrevStateCDir))
|
|
("l,logits-dir", "Path to directory containing test data for m_Logits",
|
|
cxxopts::value<std::string>(m_LogitsDir))
|
|
("H,new-state-h-dir", "Path to directory containing test data for m_NewStateH",
|
|
cxxopts::value<std::string>(m_NewStateHDir))
|
|
("C,new-state-c-dir", "Path to directory containing test data for m_NewStateC",
|
|
cxxopts::value<std::string>(m_NewStateCDir));
|
|
|
|
required.insert(required.end(), {"input-seq-dir", "prev-state-h-dir", "prev-state-c-dir", "logits-dir",
|
|
"new-state-h-dir", "new-state-c-dir"});
|
|
|
|
Model::AddCommandLineOptions(options, m_ModelCommandLineOptions, required);
|
|
}
|
|
|
|
virtual bool ProcessCommandLineOptions(const InferenceTestOptions &commonOptions) override
|
|
{
|
|
if (!ValidateDirectory(m_InputSeqDir))
|
|
{
|
|
return false;
|
|
}
|
|
|
|
if (!ValidateDirectory(m_PrevStateCDir))
|
|
{
|
|
return false;
|
|
}
|
|
|
|
if (!ValidateDirectory(m_PrevStateHDir))
|
|
{
|
|
return false;
|
|
}
|
|
|
|
if (!ValidateDirectory(m_LogitsDir))
|
|
{
|
|
return false;
|
|
}
|
|
|
|
if (!ValidateDirectory(m_NewStateCDir))
|
|
{
|
|
return false;
|
|
}
|
|
|
|
if (!ValidateDirectory(m_NewStateHDir))
|
|
{
|
|
return false;
|
|
}
|
|
|
|
m_Model = m_ConstructModel(commonOptions, m_ModelCommandLineOptions);
|
|
if (!m_Model)
|
|
{
|
|
return false;
|
|
}
|
|
m_Database = std::make_unique<DeepSpeechV1Database>(m_InputSeqDir.c_str(), m_PrevStateHDir.c_str(),
|
|
m_PrevStateCDir.c_str(), m_LogitsDir.c_str(),
|
|
m_NewStateHDir.c_str(), m_NewStateCDir.c_str());
|
|
if (!m_Database)
|
|
{
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
std::unique_ptr<IInferenceTestCase> GetTestCase(unsigned int testCaseId) override
|
|
{
|
|
std::unique_ptr<DeepSpeechV1TestCaseData> testCaseData = m_Database->GetTestCaseData(testCaseId);
|
|
if (!testCaseData)
|
|
{
|
|
return nullptr;
|
|
}
|
|
|
|
return std::make_unique<DeepSpeechV1TestCase<Model>>(*m_Model, testCaseId, *testCaseData);
|
|
}
|
|
|
|
private:
|
|
typename Model::CommandLineOptions m_ModelCommandLineOptions;
|
|
std::function<std::unique_ptr<Model>(const InferenceTestOptions&,
|
|
typename Model::CommandLineOptions)> m_ConstructModel;
|
|
std::unique_ptr<Model> m_Model;
|
|
|
|
std::string m_InputSeqDir;
|
|
std::string m_PrevStateCDir;
|
|
std::string m_PrevStateHDir;
|
|
std::string m_LogitsDir;
|
|
std::string m_NewStateCDir;
|
|
std::string m_NewStateHDir;
|
|
|
|
std::unique_ptr<DeepSpeechV1Database> m_Database;
|
|
};
|
|
|
|
} // anonymous namespace
|