forked from Qortal/Brooklyn
253 lines
7.6 KiB
C++
253 lines
7.6 KiB
C++
//
|
|
// Copyright © 2017 Arm Ltd. All rights reserved.
|
|
// SPDX-License-Identifier: MIT
|
|
//
|
|
#include "InferenceTest.hpp"
|
|
|
|
#include <armnn/utility/Assert.hpp>
|
|
#include <armnnUtils/Filesystem.hpp>
|
|
|
|
#include "../src/armnn/Profiling.hpp"
|
|
#include <cxxopts/cxxopts.hpp>
|
|
|
|
#include <fstream>
|
|
#include <iostream>
|
|
#include <iomanip>
|
|
#include <array>
|
|
|
|
using namespace std;
|
|
using namespace std::chrono;
|
|
using namespace armnn::test;
|
|
|
|
namespace armnn
|
|
{
|
|
namespace test
|
|
{
|
|
/// Parse the command line of an ArmNN (or referencetests) inference test program.
|
|
/// \return false if any error occurred during options processing, otherwise true
|
|
bool ParseCommandLine(int argc, char** argv, IInferenceTestCaseProvider& testCaseProvider,
|
|
InferenceTestOptions& outParams)
|
|
{
|
|
cxxopts::Options options("InferenceTest", "Inference iteration parameters");
|
|
|
|
try
|
|
{
|
|
// Adds generic options needed for all inference tests.
|
|
options
|
|
.allow_unrecognised_options()
|
|
.add_options()
|
|
("h,help", "Display help messages")
|
|
("i,iterations", "Sets the number of inferences to perform. If unset, will only be run once.",
|
|
cxxopts::value<unsigned int>(outParams.m_IterationCount)->default_value("0"))
|
|
("inference-times-file",
|
|
"If non-empty, each individual inference time will be recorded and output to this file",
|
|
cxxopts::value<std::string>(outParams.m_InferenceTimesFile)->default_value(""))
|
|
("e,event-based-profiling", "Enables built in profiler. If unset, defaults to off.",
|
|
cxxopts::value<bool>(outParams.m_EnableProfiling)->default_value("0"));
|
|
|
|
std::vector<std::string> required; //to be passed as reference to derived inference tests
|
|
|
|
// Adds options specific to the ITestCaseProvider.
|
|
testCaseProvider.AddCommandLineOptions(options, required);
|
|
|
|
auto result = options.parse(argc, argv);
|
|
|
|
if (result.count("help"))
|
|
{
|
|
std::cout << options.help() << std::endl;
|
|
return false;
|
|
}
|
|
|
|
CheckRequiredOptions(result, required);
|
|
|
|
}
|
|
catch (const cxxopts::OptionException& e)
|
|
{
|
|
std::cerr << e.what() << std::endl << options.help() << std::endl;
|
|
return false;
|
|
}
|
|
catch (const std::exception& e)
|
|
{
|
|
ARMNN_ASSERT_MSG(false, "Caught unexpected exception");
|
|
std::cerr << "Fatal internal error: " << e.what() << std::endl;
|
|
return false;
|
|
}
|
|
|
|
if (!testCaseProvider.ProcessCommandLineOptions(outParams))
|
|
{
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool ValidateDirectory(std::string& dir)
|
|
{
|
|
if (dir.empty())
|
|
{
|
|
std::cerr << "No directory specified" << std::endl;
|
|
return false;
|
|
}
|
|
|
|
if (dir[dir.length() - 1] != '/')
|
|
{
|
|
dir += "/";
|
|
}
|
|
|
|
if (!fs::exists(dir))
|
|
{
|
|
std::cerr << "Given directory " << dir << " does not exist" << std::endl;
|
|
return false;
|
|
}
|
|
|
|
if (!fs::is_directory(dir))
|
|
{
|
|
std::cerr << "Given directory [" << dir << "] is not a directory" << std::endl;
|
|
return false;
|
|
}
|
|
|
|
return true;
|
|
}
|
|
|
|
bool InferenceTest(const InferenceTestOptions& params,
|
|
const std::vector<unsigned int>& defaultTestCaseIds,
|
|
IInferenceTestCaseProvider& testCaseProvider)
|
|
{
|
|
#if !defined (NDEBUG)
|
|
if (params.m_IterationCount > 0) // If just running a few select images then don't bother to warn.
|
|
{
|
|
ARMNN_LOG(warning) << "Performance test running in DEBUG build - results may be inaccurate.";
|
|
}
|
|
#endif
|
|
|
|
double totalTime = 0;
|
|
unsigned int nbProcessed = 0;
|
|
bool success = true;
|
|
|
|
// Opens the file to write inference times too, if needed.
|
|
ofstream inferenceTimesFile;
|
|
const bool recordInferenceTimes = !params.m_InferenceTimesFile.empty();
|
|
if (recordInferenceTimes)
|
|
{
|
|
inferenceTimesFile.open(params.m_InferenceTimesFile.c_str(), ios_base::trunc | ios_base::out);
|
|
if (!inferenceTimesFile.good())
|
|
{
|
|
ARMNN_LOG(error) << "Failed to open inference times file for writing: "
|
|
<< params.m_InferenceTimesFile;
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Create a profiler and register it for the current thread.
|
|
std::unique_ptr<IProfiler> profiler = std::make_unique<IProfiler>();
|
|
ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
|
|
|
|
// Enable profiling if requested.
|
|
profiler->EnableProfiling(params.m_EnableProfiling);
|
|
|
|
// Run a single test case to 'warm-up' the model. The first one can sometimes take up to 10x longer
|
|
std::unique_ptr<IInferenceTestCase> warmupTestCase = testCaseProvider.GetTestCase(0);
|
|
if (warmupTestCase == nullptr)
|
|
{
|
|
ARMNN_LOG(error) << "Failed to load test case";
|
|
return false;
|
|
}
|
|
|
|
try
|
|
{
|
|
warmupTestCase->Run();
|
|
}
|
|
catch (const TestFrameworkException& testError)
|
|
{
|
|
ARMNN_LOG(error) << testError.what();
|
|
return false;
|
|
}
|
|
|
|
const unsigned int nbTotalToProcess = params.m_IterationCount > 0 ? params.m_IterationCount
|
|
: static_cast<unsigned int>(defaultTestCaseIds.size());
|
|
|
|
for (; nbProcessed < nbTotalToProcess; nbProcessed++)
|
|
{
|
|
const unsigned int testCaseId = params.m_IterationCount > 0 ? nbProcessed : defaultTestCaseIds[nbProcessed];
|
|
std::unique_ptr<IInferenceTestCase> testCase = testCaseProvider.GetTestCase(testCaseId);
|
|
|
|
if (testCase == nullptr)
|
|
{
|
|
ARMNN_LOG(error) << "Failed to load test case";
|
|
return false;
|
|
}
|
|
|
|
time_point<high_resolution_clock> predictStart;
|
|
time_point<high_resolution_clock> predictEnd;
|
|
|
|
TestCaseResult result = TestCaseResult::Ok;
|
|
|
|
try
|
|
{
|
|
predictStart = high_resolution_clock::now();
|
|
|
|
testCase->Run();
|
|
|
|
predictEnd = high_resolution_clock::now();
|
|
|
|
// duration<double> will convert the time difference into seconds as a double by default.
|
|
double timeTakenS = duration<double>(predictEnd - predictStart).count();
|
|
totalTime += timeTakenS;
|
|
|
|
// Outputss inference times, if needed.
|
|
if (recordInferenceTimes)
|
|
{
|
|
inferenceTimesFile << testCaseId << " " << (timeTakenS * 1000.0) << std::endl;
|
|
}
|
|
|
|
result = testCase->ProcessResult(params);
|
|
|
|
}
|
|
catch (const TestFrameworkException& testError)
|
|
{
|
|
ARMNN_LOG(error) << testError.what();
|
|
result = TestCaseResult::Abort;
|
|
}
|
|
|
|
switch (result)
|
|
{
|
|
case TestCaseResult::Ok:
|
|
break;
|
|
case TestCaseResult::Abort:
|
|
return false;
|
|
case TestCaseResult::Failed:
|
|
// This test failed so we will fail the entire program eventually, but keep going for now.
|
|
success = false;
|
|
break;
|
|
default:
|
|
ARMNN_ASSERT_MSG(false, "Unexpected TestCaseResult");
|
|
return false;
|
|
}
|
|
}
|
|
|
|
const double averageTimePerTestCaseMs = totalTime / nbProcessed * 1000.0f;
|
|
|
|
ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
|
|
"Total time for " << nbProcessed << " test cases: " << totalTime << " seconds";
|
|
ARMNN_LOG(info) << std::fixed << std::setprecision(3) <<
|
|
"Average time per test case: " << averageTimePerTestCaseMs << " ms";
|
|
|
|
// if profiling is enabled print out the results
|
|
if (profiler && profiler->IsProfilingEnabled())
|
|
{
|
|
profiler->Print(std::cout);
|
|
}
|
|
|
|
if (!success)
|
|
{
|
|
ARMNN_LOG(error) << "One or more test cases failed";
|
|
return false;
|
|
}
|
|
|
|
return testCaseProvider.OnInferenceTestFinished();
|
|
}
|
|
|
|
} // namespace test
|
|
|
|
} // namespace armnn
|