#include "services/device/public/cpp/hid/hid_collection.h"
#include <algorithm>
#include <limits>
#include <utility>
#include "base/format_macros.h"
#include "base/memory/ptr_util.h"
#include "base/memory/raw_ref.h"
#include "base/strings/stringprintf.h"
#include "services/device/public/cpp/hid/hid_item_state_table.h"
namespace device {
namespace {
static constexpr uint32_t kMaxItemReportSizeBits = 32;
static constexpr uint64_t kMaxReasonableReportLengthBits =
std::numeric_limits<uint16_t>::max() * 8;
static constexpr int kMaxReasonableCollectionDepth = 50;
}
HidCollection::HidCollection(HidCollection* parent,
uint32_t usage_page,
uint32_t usage,
uint32_t type)
: parent_(parent), usage_(usage, usage_page), collection_type_(type) {}
HidCollection::~HidCollection() = default;
std::vector<std::unique_ptr<HidCollection>> HidCollection::BuildCollections(
const std::vector<std::unique_ptr<HidReportDescriptorItem>>& items) {
std::vector<std::unique_ptr<HidCollection>> collections;
HidItemStateTable state;
int depth = 0;
for (const auto& current_item : items) {
switch (current_item->tag()) {
case HidReportDescriptorItem::kTagCollection:
++depth;
if (depth <= kMaxReasonableCollectionDepth)
AddCollection(*current_item, collections, state);
state.local.Reset();
break;
case HidReportDescriptorItem::kTagEndCollection:
if (depth <= kMaxReasonableCollectionDepth) {
if (state.collection)
state.collection = state.collection->parent_;
}
state.local.Reset();
if (depth > 0)
--depth;
break;
case HidReportDescriptorItem::kTagInput:
case HidReportDescriptorItem::kTagOutput:
case HidReportDescriptorItem::kTagFeature:
if (state.collection) {
auto* collection = state.collection.get();
while (collection) {
collection->AddReportItem(current_item->tag(),
current_item->GetShortData(), state);
collection = collection->parent_;
}
}
state.local.Reset();
break;
case HidReportDescriptorItem::kTagPush:
if (!state.global_stack.empty())
state.global_stack.push_back(state.global_stack.back());
break;
case HidReportDescriptorItem::kTagPop:
if (!state.global_stack.empty())
state.global_stack.pop_back();
break;
case HidReportDescriptorItem::kTagReportId:
if (state.collection) {
state.report_id = current_item->GetShortData();
auto* collection = state.collection.get();
while (collection) {
collection->report_ids_.push_back(state.report_id);
collection = collection->parent_;
}
}
break;
case HidReportDescriptorItem::kTagUsagePage:
case HidReportDescriptorItem::kTagLogicalMinimum:
case HidReportDescriptorItem::kTagLogicalMaximum:
case HidReportDescriptorItem::kTagPhysicalMinimum:
case HidReportDescriptorItem::kTagPhysicalMaximum:
case HidReportDescriptorItem::kTagUnitExponent:
case HidReportDescriptorItem::kTagUnit:
case HidReportDescriptorItem::kTagReportSize:
case HidReportDescriptorItem::kTagReportCount:
case HidReportDescriptorItem::kTagUsage:
case HidReportDescriptorItem::kTagUsageMinimum:
case HidReportDescriptorItem::kTagUsageMaximum:
case HidReportDescriptorItem::kTagDesignatorIndex:
case HidReportDescriptorItem::kTagDesignatorMinimum:
case HidReportDescriptorItem::kTagDesignatorMaximum:
case HidReportDescriptorItem::kTagStringIndex:
case HidReportDescriptorItem::kTagStringMinimum:
case HidReportDescriptorItem::kTagStringMaximum:
case HidReportDescriptorItem::kTagDelimiter:
state.SetItemValue(current_item->tag(), current_item->GetShortData(),
current_item->payload_size());
break;
default:
break;
}
}
return collections;
}
void HidCollection::AddCollection(
const HidReportDescriptorItem& item,
std::vector<std::unique_ptr<HidCollection>>& collections,
HidItemStateTable& state) {
uint32_t usage = state.local.usages.empty() ? 0 : state.local.usages.front();
uint32_t usage_page = (usage >> 16) & 0xffff;
if (usage_page == 0 && !state.global_stack.empty())
usage_page = state.global_stack.back().usage_page;
uint32_t collection_type = item.GetShortData();
auto collection = std::make_unique<HidCollection>(
state.collection, usage_page, usage, collection_type);
if (state.collection) {
state.collection->children_.push_back(std::move(collection));
state.collection = state.collection->children_.back().get();
} else {
collections.push_back(std::move(collection));
state.collection = collections.back().get();
}
}
void HidCollection::AddChildForTesting(
std::unique_ptr<HidCollection> collection) {
children_.push_back(std::move(collection));
}
void HidCollection::AddReportItem(HidReportDescriptorItem::Tag tag,
uint32_t report_info,
const HidItemStateTable& state) {
std::unordered_map<uint8_t, HidReport>* reports = nullptr;
if (tag == HidReportDescriptorItem::kTagInput)
reports = &input_reports_;
else if (tag == HidReportDescriptorItem::kTagOutput)
reports = &output_reports_;
else if (tag == HidReportDescriptorItem::kTagFeature)
reports = &feature_reports_;
else
return;
HidReport* report = nullptr;
auto find_it = reports->find(state.report_id);
if (find_it == reports->end()) {
auto emplace_result = reports->emplace(state.report_id, HidReport());
report = &emplace_result.first->second;
} else {
report = &find_it->second;
}
report->push_back(HidReportItem::Create(tag, report_info, state));
}
void HidCollection::GetMaxReportSizes(size_t* max_input_report_bits,
size_t* max_output_report_bits,
size_t* max_feature_report_bits) const {
DCHECK(max_input_report_bits);
DCHECK(max_output_report_bits);
DCHECK(max_feature_report_bits);
struct {
const raw_ref<const std::unordered_map<uint8_t, HidReport>> reports;
const raw_ref<size_t> max_report_bits;
} report_lists[]{
{ToRawRef(input_reports_), ToRawRef(*max_input_report_bits)},
{ToRawRef(output_reports_), ToRawRef(*max_output_report_bits)},
{ToRawRef(feature_reports_), ToRawRef(*max_feature_report_bits)},
};
auto collection_info = mojom::HidCollectionInfo::New();
collection_info->usage =
mojom::HidUsageAndPage::New(usage_.usage, usage_.usage_page);
collection_info->report_ids.insert(collection_info->report_ids.end(),
report_ids_.begin(), report_ids_.end());
static constexpr int kMaxLogMessages = 100;
int log_message_count = 0;
for (const auto& entry : report_lists) {
*entry.max_report_bits = 0;
for (const auto& report : *entry.reports) {
uint64_t report_bits = 0;
for (const auto& item : report.second) {
uint64_t report_size = item->GetReportSize();
if (report_size > kMaxItemReportSizeBits) {
if (log_message_count < kMaxLogMessages) {
++log_message_count;
LOG(WARNING) << base::StringPrintf(
"encountered report item with invalid report size (%" PRIu64
">%u)",
report_size, kMaxItemReportSizeBits);
} else if (log_message_count == kMaxLogMessages) {
++log_message_count;
LOG(WARNING) << "Too many invalid report items, not reporting any "
"more for this device.";
}
}
uint64_t report_count = item->GetReportCount();
uint64_t item_bits = report_size * report_count;
if (item_bits > kMaxReasonableReportLengthBits ||
report_bits > kMaxReasonableReportLengthBits - item_bits) {
report_bits = 0;
break;
}
report_bits += item_bits;
}
DCHECK_LE(report_bits, kMaxReasonableReportLengthBits);
*entry.max_report_bits =
std::max(*entry.max_report_bits, static_cast<size_t>(report_bits));
}
}
}
mojom::HidCollectionInfoPtr HidCollection::ToMojo() const {
auto collection = mojom::HidCollectionInfo::New();
struct {
const raw_ref<const std::unordered_map<uint8_t, HidReport>> in;
const raw_ref<std::vector<mojom::HidReportDescriptionPtr>> out;
} report_lists[]{
{ToRawRef(input_reports_), ToRawRef(collection->input_reports)},
{ToRawRef(output_reports_), ToRawRef(collection->output_reports)},
{ToRawRef(feature_reports_), ToRawRef(collection->feature_reports)},
};
collection->usage =
mojom::HidUsageAndPage::New(usage_.usage, usage_.usage_page);
collection->report_ids.insert(collection->report_ids.end(),
report_ids_.begin(), report_ids_.end());
collection->collection_type = collection_type_;
for (const auto& report_list : report_lists) {
for (const auto& report : *report_list.in) {
auto report_description = mojom::HidReportDescription::New();
report_description->report_id = report.first;
for (const auto& item : report.second)
report_description->items.push_back(item->ToMojo());
report_list.out->push_back(std::move(report_description));
}
}
for (const auto& child : children_)
collection->children.push_back(child->ToMojo());
return collection;
}
}