#include "components/assist_ranker/quantized_nn_classifier.h"
#include "components/assist_ranker/nn_classifier.h"
#include "components/assist_ranker/nn_classifier_test_util.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace assist_ranker {
namespace quantized_nn_classifier {
namespace {
using ::google::protobuf::RepeatedFieldBackInserter;
using ::google::protobuf::RepeatedPtrField;
using ::std::copy;
using ::std::vector;
void CreateLayer(const vector<int>& biases,
const vector<vector<int>>& weights,
float low,
float high,
QuantizedNNLayer* layer) {
layer->set_biases(std::string(biases.begin(), biases.end()));
for (const auto& i : weights) {
layer->mutable_weights()->Add(std::string(i.begin(), i.end()));
}
layer->set_low(low);
layer->set_high(high);
}
// Builds a two-layer quantized model (hidden layer + logits layer) in which
// both layers share the same [low, high] quantization range. See CreateLayer()
// for how the integer biases/weights are packed.
QuantizedNNClassifierModel CreateModel(
    const vector<int>& hidden_biases,
    const vector<vector<int>>& hidden_weights,
    const vector<int>& logits_biases,
    const vector<vector<int>>& logits_weights,
    float low,
    float high) {
  QuantizedNNClassifierModel result;

  QuantizedNNLayer* const hidden_layer = result.mutable_hidden_layer();
  CreateLayer(hidden_biases, hidden_weights, low, high, hidden_layer);

  QuantizedNNLayer* const logits_layer = result.mutable_logits_layer();
  CreateLayer(logits_biases, logits_weights, low, high, logits_layer);

  return result;
}
// Dequantizing a model with range [0, 128] must yield the hand-written float
// model below. Every expected value is exactly half of the corresponding
// quantized byte, so the two protos can be compared byte-for-byte.
TEST(QuantizedNNClassifierTest, Dequantize) {
  const QuantizedNNClassifierModel quantized =
      CreateModel(/*hidden_biases=*/{8, 16, 32},
                  /*hidden_weights=*/{{2, 4, 6}, {10, 4, 8}},
                  /*logits_biases=*/{2},
                  /*logits_weights=*/{{4}, {2}, {6}}, /*low=*/0, /*high=*/128);
  ASSERT_TRUE(Validate(quantized));

  // Same shape as |quantized|, with each quantized byte halved.
  const NNClassifierModel expected = nn_classifier::CreateModel(
      {{4, 8, 16}}, {{1, 2, 3}, {5, 2, 4}}, {1}, {{2}, {1}, {3}});

  const NNClassifierModel dequantized = Dequantize(quantized);
  EXPECT_EQ(dequantized.SerializeAsString(), expected.SerializeAsString());
}
// Dequantizes a small 2-input, 5-hidden-unit network whose expected outputs
// follow the XOR sign pattern: negative when both inputs are equal, positive
// when they differ. The dequantized model must validate and reproduce those
// logits.
TEST(QuantizedNNClassifierTest, XorTest) {
  // Quantization range shared by the hidden and logits layers.
  constexpr float kLow = -2.96390629;
  constexpr float kHigh = 2.8636384;

  const QuantizedNNClassifierModel quantized =
      CreateModel(/*hidden_biases=*/{110, 139, 175, 55, 106},
                  /*hidden_weights=*/
                  {{228, 127, 97, 217, 158}, {55, 219, 80, 199, 152}},
                  /*logits_biases=*/{74},
                  /*logits_weights=*/{{255}, {211}, {53}, {0}, {86}}, kLow,
                  kHigh);
  ASSERT_TRUE(Validate(quantized));

  const NNClassifierModel dequantized = Dequantize(quantized);
  ASSERT_TRUE(nn_classifier::Validate(dequantized));

  // Inputs equal -> negative logit; inputs differ -> positive logit.
  EXPECT_TRUE(nn_classifier::CheckInference(dequantized, {0, 0}, {-2.7032}));
  EXPECT_TRUE(nn_classifier::CheckInference(dequantized, {0, 1}, {2.80681}));
  EXPECT_TRUE(nn_classifier::CheckInference(dequantized, {1, 0}, {2.64435}));
  EXPECT_TRUE(nn_classifier::CheckInference(dequantized, {1, 1}, {-3.17825}));
}
// Validate() must accept a consistently shaped quantized model and reject each
// malformed variant below. Every expectation carries a message because
// otherwise all the negative cases fail with an identical, undiagnosable
// "Validate(model) is true" report.
TEST(QuantizedNNClassifierTest, ValidateQuantizedNNClassifierModel) {
  EXPECT_FALSE(Validate(QuantizedNNClassifierModel()))
      << "An empty model must be rejected.";

  // Reference model: 2 inputs, 3 hidden units, 1 output, range [0, 1].
  // Each negative case below differs from it in exactly one respect.
  EXPECT_TRUE(Validate(CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0},
                                   {{0}, {0}, {0}}, 0, 1)))
      << "A consistently shaped model must be accepted.";

  EXPECT_FALSE(Validate(CreateModel({0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0},
                                    {{0}, {0}, {0}}, 0, 1)))
      << "Hidden biases must have one entry per hidden unit.";

  EXPECT_FALSE(Validate(CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0}}, {0},
                                    {{0}, {0}, {0}}, 0, 1)))
      << "Every hidden weight row must have one entry per hidden unit.";

  EXPECT_FALSE(Validate(CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0},
                                    {{0}, {0}}, 0, 1)))
      << "Logits weights must have one row per hidden unit.";

  EXPECT_FALSE(Validate(CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {},
                                    {{0}, {0}, {0}}, 0, 1)))
      << "Logits biases must have one entry per output.";

  EXPECT_FALSE(Validate(CreateModel({0, 0, 0}, {{0, 0, 0}, {0, 0, 0}}, {0},
                                    {{0}, {0}, {0}}, 1, 0)))
      << "A quantization range with low > high must be rejected.";
}
}
}
}