#ifndef MEDIA_GPU_TEST_VIDEO_FRAME_VALIDATOR_H_
#define MEDIA_GPU_TEST_VIDEO_FRAME_VALIDATOR_H_
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "base/files/file.h"
#include "base/functional/callback.h"
#include "base/memory/scoped_refptr.h"
#include "base/sequence_checker.h"
#include "base/synchronization/condition_variable.h"
#include "base/synchronization/lock.h"
#include "base/threading/thread.h"
#include "base/threading/thread_checker.h"
#include "media/base/video_types.h"
#include "media/gpu/test/video_frame_helpers.h"
#include "ui/gfx/geometry/rect.h"
// Forward declaration: this header only holds a scoped_refptr to it.
namespace gpu {
class TestSharedImageInterface;
}  // namespace gpu
namespace media {
// Forward declarations: this header refers to these types only through
// references, scoped_refptr and std::unique_ptr.
class VideoFrameMapper;
class VideoFrame;
namespace test {
class VideoFrameValidator : public VideoFrameProcessor {
public:
enum class ValidationMode {
kThreshold,
kAverage,
};
using GetModelFrameCB =
base::RepeatingCallback<scoped_refptr<const VideoFrame>(size_t)>;
using CropHelper =
base::RepeatingCallback<gfx::Rect(const VideoFrame& frame)>;
VideoFrameValidator(const VideoFrameValidator&) = delete;
VideoFrameValidator& operator=(const VideoFrameValidator&) = delete;
~VideoFrameValidator() override;
void PrintMismatchedFramesInfo() const;
size_t GetMismatchedFramesCount() const;
void ProcessVideoFrame(scoped_refptr<const VideoFrame> video_frame,
size_t frame_index) final;
bool WaitUntilDone() final;
protected:
struct MismatchedFrameInfo {
MismatchedFrameInfo(size_t frame_index) : frame_index(frame_index) {}
virtual ~MismatchedFrameInfo() = default;
virtual void Print() const = 0;
size_t frame_index;
};
VideoFrameValidator(
std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor,
CropHelper crop_helper);
bool Initialize();
SEQUENCE_CHECKER(validator_thread_sequence_checker_);
bool ShouldCrop() const { return static_cast<bool>(crop_helper_); }
scoped_refptr<VideoFrame> CloneAndCropFrame(
scoped_refptr<const VideoFrame> frame) const;
private:
void CleanUpOnValidatorThread();
void Destroy();
void ProcessVideoFrameTask(scoped_refptr<const VideoFrame> video_frame,
size_t frame_index);
virtual std::unique_ptr<MismatchedFrameInfo> Validate(
scoped_refptr<const VideoFrame> frame,
size_t frame_index) = 0;
virtual bool Passed() const;
std::unique_ptr<VideoFrameMapper> video_frame_mapper_;
std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor_;
scoped_refptr<gpu::TestSharedImageInterface> test_sii_;
const CropHelper crop_helper_;
size_t num_frames_validating_ GUARDED_BY(frame_validator_lock_);
std::vector<std::unique_ptr<MismatchedFrameInfo>> mismatched_frames_
GUARDED_BY(frame_validator_lock_);
base::Thread frame_validator_thread_;
mutable base::Lock frame_validator_lock_;
mutable base::ConditionVariable frame_validator_cv_;
SEQUENCE_CHECKER(validator_sequence_checker_);
};
// Default |row_height| for BottomRowCrop().
constexpr int kDefaultBottomRowCropHeight = 2;

// Crop function whose name suggests it returns the bottom |row_height| rows
// of |frame| -- TODO confirm in the .cc. To be used as a CropHelper,
// |row_height| must first be bound, e.g.
// base::BindRepeating(&BottomRowCrop, kDefaultBottomRowCropHeight).
gfx::Rect BottomRowCrop(int row_height, const VideoFrame& frame);
// Validates frames by computing an MD5 checksum of each frame and comparing
// it with |expected_frame_checksums|. Presumably the frame is first converted
// to |validation_format| and the checksum list is indexed by frame index --
// TODO confirm in the .cc.
class MD5VideoFrameValidator : public VideoFrameValidator {
 public:
  // Creates a validator. |validation_format| defaults to PIXEL_FORMAT_I420;
  // |corrupt_frame_processor| and |crop_helper| are optional.
  static std::unique_ptr<MD5VideoFrameValidator> Create(
      const std::vector<std::string>& expected_frame_checksums,
      VideoPixelFormat validation_format = PIXEL_FORMAT_I420,
      std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor = nullptr,
      CropHelper crop_helper = CropHelper());

  MD5VideoFrameValidator(const MD5VideoFrameValidator&) = delete;
  MD5VideoFrameValidator& operator=(const MD5VideoFrameValidator&) = delete;

  ~MD5VideoFrameValidator() override;

 private:
  struct MD5MismatchedFrameInfo;

  MD5VideoFrameValidator(
      const std::vector<std::string>& expected_frame_checksums,
      VideoPixelFormat validation_format,
      std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor,
      CropHelper crop_helper);

  // VideoFrameValidator implementation.
  std::unique_ptr<MismatchedFrameInfo> Validate(
      scoped_refptr<const VideoFrame> frame,
      size_t frame_index) override;

  // Returns the MD5 checksum of |video_frame| as a string (presumably a hex
  // digest, comparable with the expected checksums).
  std::string ComputeMD5FromVideoFrame(const VideoFrame& video_frame) const;

  const std::vector<std::string> expected_frame_checksums_;
  const VideoPixelFormat validation_format_;
};
// Validates frames by comparing them with the model frames returned by
// |get_model_frame_cb|, allowing a difference of up to |tolerance|
// (presumably per pixel value -- TODO confirm in the .cc).
class RawVideoFrameValidator : public VideoFrameValidator {
 public:
  static constexpr uint8_t kDefaultTolerance = 4;

  // Creates a validator; |corrupt_frame_processor| and |crop_helper| are
  // optional and |tolerance| defaults to kDefaultTolerance.
  static std::unique_ptr<RawVideoFrameValidator> Create(
      const GetModelFrameCB& get_model_frame_cb,
      std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor = nullptr,
      uint8_t tolerance = kDefaultTolerance,
      CropHelper crop_helper = CropHelper());

  ~RawVideoFrameValidator() override;

 private:
  struct RawMismatchedFrameInfo;

  RawVideoFrameValidator(
      const GetModelFrameCB& get_model_frame_cb,
      std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor,
      uint8_t tolerance,
      CropHelper crop_helper);

  // VideoFrameValidator implementation.
  std::unique_ptr<MismatchedFrameInfo> Validate(
      scoped_refptr<const VideoFrame> frame,
      size_t frame_index) override;

  const GetModelFrameCB get_model_frame_cb_;
  const uint8_t tolerance_;
};
// Validates frames by computing the PSNR of each frame against the model
// frame returned by |get_model_frame_cb|, using |tolerance| as the
// threshold (presumably a minimum acceptable PSNR -- TODO confirm). How the
// scores become a verdict depends on |validation_mode| and Passed().
class PSNRVideoFrameValidator : public VideoFrameValidator {
 public:
  static constexpr double kDefaultTolerance = 25.0;

  // Creates a validator; every parameter except |get_model_frame_cb| has a
  // default.
  static std::unique_ptr<PSNRVideoFrameValidator> Create(
      const GetModelFrameCB& get_model_frame_cb,
      std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor = nullptr,
      ValidationMode validation_mode = ValidationMode::kThreshold,
      double tolerance = kDefaultTolerance,
      CropHelper crop_helper = CropHelper());

  // Recorded PSNR values, presumably keyed by frame index. |psnr_| is not
  // guarded by a lock, so presumably read it only after WaitUntilDone().
  const std::map<size_t, double>& GetPSNRValues() const { return psnr_; }

  ~PSNRVideoFrameValidator() override;

 private:
  struct PSNRMismatchedFrameInfo;

  PSNRVideoFrameValidator(
      const GetModelFrameCB& get_model_frame_cb,
      std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor,
      ValidationMode validation_mode,
      double tolerance,
      CropHelper crop_helper);

  // VideoFrameValidator implementation.
  std::unique_ptr<MismatchedFrameInfo> Validate(
      scoped_refptr<const VideoFrame> frame,
      size_t frame_index) override;
  bool Passed() const override;

  const GetModelFrameCB get_model_frame_cb_;
  // NOTE(review): this shadows VideoFrameValidator's private |crop_helper_|
  // (the one ShouldCrop() consults); the SSIM and log-likelihood validators
  // do not redeclare it. Confirm the constructor initializes it and that it
  // is actually needed -- otherwise it is a candidate for removal.
  const CropHelper crop_helper_;
  const double tolerance_;
  const ValidationMode validation_mode_;
  std::map<size_t, double> psnr_;
};
// Validates frames by computing the SSIM of each frame against the model
// frame returned by |get_model_frame_cb|, using |tolerance| as the
// threshold (presumably a minimum acceptable SSIM -- TODO confirm). How the
// scores become a verdict depends on |validation_mode| and Passed().
class SSIMVideoFrameValidator : public VideoFrameValidator {
 public:
  static constexpr double kDefaultTolerance = 0.70;

  // Creates a validator; every parameter except |get_model_frame_cb| has a
  // default.
  static std::unique_ptr<SSIMVideoFrameValidator> Create(
      const GetModelFrameCB& get_model_frame_cb,
      std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor = nullptr,
      ValidationMode validation_mode = ValidationMode::kThreshold,
      double tolerance = kDefaultTolerance,
      CropHelper crop_helper = CropHelper());

  // Recorded SSIM values, presumably keyed by frame index. |ssim_| is not
  // guarded by a lock, so presumably read it only after WaitUntilDone().
  const std::map<size_t, double>& GetSSIMValues() const { return ssim_; }

  ~SSIMVideoFrameValidator() override;

 private:
  struct SSIMMismatchedFrameInfo;

  SSIMVideoFrameValidator(
      const GetModelFrameCB& get_model_frame_cb,
      std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor,
      ValidationMode validation_mode,
      double tolerance,
      CropHelper crop_helper);

  // VideoFrameValidator implementation.
  std::unique_ptr<MismatchedFrameInfo> Validate(
      scoped_refptr<const VideoFrame> frame,
      size_t frame_index) override;
  bool Passed() const override;

  const GetModelFrameCB get_model_frame_cb_;
  const double tolerance_;
  const ValidationMode validation_mode_;
  std::map<size_t, double> ssim_;
};
// Validates frames by computing a log likelihood ratio of each frame against
// the model frame returned by |get_model_frame_cb|, using |tolerance| as the
// threshold (its exact meaning is defined in the .cc -- TODO document).
class LogLikelihoodRatioVideoFrameValidator : public VideoFrameValidator {
 public:
  static constexpr double kDefaultTolerance = 1.015;

  // Creates a validator; every parameter except |get_model_frame_cb| has a
  // default.
  static std::unique_ptr<LogLikelihoodRatioVideoFrameValidator> Create(
      const GetModelFrameCB& get_model_frame_cb,
      std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor = nullptr,
      ValidationMode validation_mode = ValidationMode::kThreshold,
      double tolerance = kDefaultTolerance,
      CropHelper crop_helper = CropHelper());

  // Recorded ratios, presumably keyed by frame index. The map is not guarded
  // by a lock, so presumably read it only after WaitUntilDone().
  const std::map<size_t, double>& get_log_likelihood_ratio_values() const {
    return log_likelihood_ratios_;
  }

  ~LogLikelihoodRatioVideoFrameValidator() override;

 private:
  struct LogLikelihoodRatioMismatchedFrameInfo;

  // NOTE(review): unlike the PSNR/SSIM validators, this class keeps no
  // |validation_mode_| member and does not override Passed(), so
  // |validation_mode| appears to have no lasting effect here. Confirm
  // against the .cc whether ValidationMode::kAverage is meant to work.
  LogLikelihoodRatioVideoFrameValidator(
      const GetModelFrameCB& get_model_frame_cb,
      std::unique_ptr<VideoFrameProcessor> corrupt_frame_processor,
      ValidationMode validation_mode,
      double tolerance,
      CropHelper crop_helper);

  // VideoFrameValidator implementation.
  std::unique_ptr<MismatchedFrameInfo> Validate(
      scoped_refptr<const VideoFrame> frame,
      size_t frame_index) override;

  const GetModelFrameCB get_model_frame_cb_;
  const double tolerance_;
  std::map<size_t, double> log_likelihood_ratios_;
};
}
}
#endif