#include "sandbox/linux/bpf_dsl/codegen.h"
#include <stddef.h>
#include <stdint.h>
#include <map>
#include <string_view>
#include <utility>
#include <vector>
#include "crypto/hash.h"
#include "sandbox/linux/system_headers/linux_filter.h"
#include "testing/gtest/include/gtest/gtest.h"
namespace sandbox {
namespace {
// Structural fingerprint of a BPF instruction and everything reachable from
// it. The digest is a SHA-256 over the instruction's |code| and |k| fields
// followed by the digests of its jump-true and jump-false successors.
class Hash {
 public:
  // The all-zero digest. It doubles as the default successor hash for
  // instructions constructed without a |jt| and/or |jf|.
  static const Hash kZero;

  // Yields the all-zero digest.
  Hash() {}

  // Hashes one instruction. |jt| and |jf| are the hashes of the instruction's
  // successors and default to kZero when omitted.
  Hash(uint16_t code,
       uint32_t k,
       const Hash& jt = kZero,
       const Hash& jf = kZero) {
    crypto::hash::Hasher hasher(crypto::hash::HashKind::kSha256);
    // The feed order (code, k, jt, jf) is part of the hash definition.
    hasher.Update(base::byte_span_from_ref(code));
    hasher.Update(base::byte_span_from_ref(k));
    hasher.Update(jt.digest());
    hasher.Update(jf.digest());
    hasher.Finish(digest_);
  }

  Hash(const Hash& hash) = default;
  Hash& operator=(const Hash& rhs) = default;

  // Two hashes are equal exactly when their digests match byte for byte.
  friend bool operator==(const Hash& lhs, const Hash& rhs) {
    return lhs.digest_ == rhs.digest_;
  }
  friend bool operator!=(const Hash& lhs, const Hash& rhs) {
    return lhs.digest_ != rhs.digest_;
  }

  // Read-only view of the raw digest bytes.
  base::span<const uint8_t> digest() const { return digest_; }

 private:
  // Zero-filled until a hasher writes into it.
  std::array<uint8_t, crypto::hash::kSha256Size> digest_ = {};
};

const Hash Hash::kZero;
// Sanity check for the Hash helper itself: builds hashes that differ in code,
// k, jump-true successor and jump-false successor, and verifies that every
// one of them compares equal only to itself.
TEST(CodeGen, HashSanity) {
  std::vector<Hash> hashes;
  hashes.push_back(Hash::kZero);

  // Leaf instructions: vary only |code| and |k|.
  for (int bits = 0; bits < 4; ++bits) {
    const Hash leaf(bits & 1, bits & 2);
    hashes.push_back(leaf);
  }

  // Instructions with a jump-true successor but no jump-false successor.
  for (int bits = 0; bits < 16; ++bits) {
    const Hash jt(bits & 4, bits & 8);
    hashes.push_back(Hash(bits & 1, bits & 2, jt));
  }

  // Instructions with both successors.
  for (int bits = 0; bits < 64; ++bits) {
    const Hash jt(bits & 4, bits & 8);
    const Hash jf(bits & 16, bits & 32);
    hashes.push_back(Hash(bits & 1, bits & 2, jt, jf));
  }

  // Compare every ordered pair; only an element paired with itself may match.
  const size_t count = hashes.size();
  for (size_t i = 0; i < count; ++i) {
    for (size_t j = 0; j < count; ++j) {
      const Hash& a = hashes[i];
      const Hash& b = hashes[j];
      if (i != j) {
        EXPECT_NE(a, b);
        continue;
      }
      EXPECT_EQ(a, b);
    }
  }
}
// Fixture that builds instruction graphs through CodeGen while keeping an
// independently computed Hash for every node it creates. RunTest() then
// compiles a graph and checks that the emitted program hashes to the same
// value as the graph it came from.
class ProgramTest : public ::testing::Test {
 public:
  ProgramTest(const ProgramTest&) = delete;
  ProgramTest& operator=(const ProgramTest&) = delete;

 protected:
  ProgramTest() = default;

  // Wraps CodeGen::MakeInstruction(): creates the node, expects it to be
  // non-null, and records the hash the node should have. Returns the node.
  CodeGen::Node MakeInstruction(uint16_t code,
                                uint32_t k,
                                CodeGen::Node jt = CodeGen::kNullNode,
                                CodeGen::Node jf = CodeGen::kNullNode) {
    const CodeGen::Node node = gen_.MakeInstruction(code, k, jt, jf);
    EXPECT_NE(CodeGen::kNullNode, node);

    const Hash expected(code, k, Lookup(jt), Lookup(jf));
    // insert() keeps an already-present entry, so when CodeGen hands back a
    // node seen before, this confirms the earlier hash matches the new one.
    const auto entry = node_hashes_.insert({node, expected}).first;
    EXPECT_EQ(expected, entry->second);
    return node;
  }

  // Compiles the graph rooted at |head| and verifies that the hash of the
  // program's first instruction equals the hash recorded for |head|.
  void RunTest(CodeGen::Node head) {
    const CodeGen::Program program = gen_.Compile(head);

    // Hashes are computed back to front so that every successor's hash is
    // available by the time the instruction referring to it is visited.
    // |next| is the index of the instruction following the current one; all
    // jump offsets below are applied relative to it.
    std::vector<Hash> prog_hashes(program.size());
    for (size_t next = program.size(); next > 0; --next) {
      const size_t pc = next - 1;
      const sock_filter& insn = program.at(pc);

      switch (BPF_CLASS(insn.code)) {
        case BPF_JMP:
          if (BPF_OP(insn.code) == BPF_JA) {
            // An unconditional jump simply takes on the hash of its target,
            // so it does not contribute to the program's hash.
            prog_hashes.at(pc) = prog_hashes.at(next + insn.k);
          } else {
            prog_hashes.at(pc) =
                Hash(insn.code, insn.k, prog_hashes.at(next + insn.jt),
                     prog_hashes.at(next + insn.jf));
          }
          break;
        case BPF_RET:
          // Returns have no successors.
          prog_hashes.at(pc) = Hash(insn.code, insn.k);
          break;
        default:
          // Everything else falls through to the following instruction.
          prog_hashes.at(pc) = Hash(insn.code, insn.k, prog_hashes.at(next));
          break;
      }
    }

    EXPECT_EQ(Lookup(head), prog_hashes.at(0));
  }

 private:
  // Returns the hash recorded for |next|. The null node maps to Hash::kZero;
  // an unknown node records a test failure and also yields Hash::kZero.
  const Hash& Lookup(CodeGen::Node next) const {
    if (next == CodeGen::kNullNode) {
      return Hash::kZero;
    }
    if (const auto it = node_hashes_.find(next); it != node_hashes_.end()) {
      return it->second;
    }
    ADD_FAILURE() << "No hash found for node " << next;
    return Hash::kZero;
  }

  CodeGen gen_{};
  // Expected hash for every node created through MakeInstruction().
  std::map<CodeGen::Node, Hash> node_hashes_{};
};
// Smallest case: a program that is just a single "RET 0".
TEST_F(ProgramTest, OneInstruction) {
  RunTest(MakeInstruction(BPF_RET + BPF_K, 0));
}
// One conditional jump whose two outcomes are distinct return instructions:
//   if (A == 42) RET 1 else RET 0
TEST_F(ProgramTest, SimpleBranch) {
  const CodeGen::Node on_equal = MakeInstruction(BPF_RET + BPF_K, 1);
  const CodeGen::Node on_differ = MakeInstruction(BPF_RET + BPF_K, 0);
  const CodeGen::Node branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 42, on_equal, on_differ);
  RunTest(branch);
}
// A conditional jump whose true and false targets are the very same node:
//   if (A == 42) RET 0 else RET 0
TEST_F(ProgramTest, AtypicalBranch) {
  const CodeGen::Node shared_ret = MakeInstruction(BPF_RET + BPF_K, 0);
  const CodeGen::Node branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 42, shared_ret, shared_ret);
  RunTest(branch);
}
// Builds a small graph with shared and duplicated sub-sequences:
//   head:         JEQ 42 ? load23 : inner_branch
//   load23:       LD 23, then inner_branch
//   inner_branch: JEQ 42 ? load42 : load42_again
//   load42:       LD 42, then RET 42
// The head reaches |inner_branch| both directly and through |load23|.
TEST_F(ProgramTest, Complex) {
  const CodeGen::Node ret42 = MakeInstruction(BPF_RET + BPF_K, 42);
  const CodeGen::Node load42 =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 42, ret42);
  const CodeGen::Node load42_alias = load42;

  // Rebuild the identical "LD 42; RET 42" sequence from scratch; CodeGen is
  // expected to hand back the node that already exists.
  const CodeGen::Node ret42_again = MakeInstruction(BPF_RET + BPF_K, 42);
  const CodeGen::Node load42_again =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 42, ret42_again);
  EXPECT_EQ(load42_alias, load42_again);

  const CodeGen::Node inner_branch = MakeInstruction(
      BPF_JMP + BPF_JEQ + BPF_K, 42, load42_alias, load42_again);
  const CodeGen::Node load23 =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 23, inner_branch);

  // |load23| is followed by neither a jump nor a return of its own; it just
  // continues into |inner_branch|, which the head also targets directly.
  const CodeGen::Node head =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 42, load23, inner_branch);

  RunTest(head);
}
// Several paths that share the same tail but reach it through a different
// number of loads:
//   head:        LD 1, then top_branch
//   top_branch:  JEQ 1 ? mid_load : mid_branch
//   mid_load:    LD 0, then mid_branch
//   mid_branch:  JEQ 2 ? tail_load : tail_branch
//   tail_load:   LD 0, then tail_branch
//   tail_branch: JEQ 1 ? RET 0 : RET 1
TEST_F(ProgramTest, ConfusingTails) {
  const CodeGen::Node ret1 = MakeInstruction(BPF_RET + BPF_K, 1);
  const CodeGen::Node ret0 = MakeInstruction(BPF_RET + BPF_K, 0);
  const CodeGen::Node tail_branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 1, ret0, ret1);
  const CodeGen::Node tail_load =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 0, tail_branch);
  const CodeGen::Node mid_branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 2, tail_load, tail_branch);
  const CodeGen::Node mid_load =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 0, mid_branch);
  const CodeGen::Node top_branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 1, mid_load, mid_branch);
  const CodeGen::Node head =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 1, top_branch);

  RunTest(head);
}
// Reduced form of ConfusingTails in which the shared tail is a single return:
//   head:       LD 1, then top_branch
//   top_branch: JEQ 1 ? mid_load : mid_branch
//   mid_load:   LD 0, then mid_branch
//   mid_branch: JEQ 2 ? tail_load : RET 1
//   tail_load:  LD 0, then RET 1
TEST_F(ProgramTest, ConfusingTailsBasic) {
  const CodeGen::Node ret1 = MakeInstruction(BPF_RET + BPF_K, 1);
  const CodeGen::Node tail_load =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 0, ret1);
  const CodeGen::Node mid_branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 2, tail_load, ret1);
  const CodeGen::Node mid_load =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 0, mid_branch);
  const CodeGen::Node top_branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 1, mid_load, mid_branch);
  const CodeGen::Node head =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 1, top_branch);

  RunTest(head);
}
// Variant of ConfusingTails in which two branches lead to the identical
// instruction "RET 42", requested twice through MakeInstruction():
//   head:        LD 1, then top_branch
//   top_branch:  JEQ 1 ? RET 42 : mid_branch
//   mid_branch:  JEQ 2 ? RET 42 : tail_branch
//   tail_branch: JEQ 1 ? RET 0 : RET 1
TEST_F(ProgramTest, ConfusingTailsMergeable) {
  const CodeGen::Node ret1 = MakeInstruction(BPF_RET + BPF_K, 1);
  const CodeGen::Node ret0 = MakeInstruction(BPF_RET + BPF_K, 0);
  const CodeGen::Node tail_branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 1, ret0, ret1);
  const CodeGen::Node ret42_mid = MakeInstruction(BPF_RET + BPF_K, 42);
  const CodeGen::Node mid_branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 2, ret42_mid, tail_branch);
  // Deliberately requested a second time rather than reusing |ret42_mid|.
  const CodeGen::Node ret42_top = MakeInstruction(BPF_RET + BPF_K, 42);
  const CodeGen::Node top_branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 1, ret42_top, mid_branch);
  const CodeGen::Node head =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 1, top_branch);

  RunTest(head);
}
// Requesting an instruction that already exists must return the existing
// node, both for single instructions and for multi-instruction sequences.
TEST_F(ProgramTest, InstructionFolding) {
  // Single instructions: repeated requests give back the first node, and
  // creating an unrelated node in between does not disturb that.
  const CodeGen::Node ret0 = MakeInstruction(BPF_RET + BPF_K, 0);
  EXPECT_EQ(ret0, MakeInstruction(BPF_RET + BPF_K, 0));
  const CodeGen::Node ret1 = MakeInstruction(BPF_RET + BPF_K, 1);
  EXPECT_EQ(ret0, MakeInstruction(BPF_RET + BPF_K, 0));
  EXPECT_EQ(ret1, MakeInstruction(BPF_RET + BPF_K, 1));
  EXPECT_EQ(ret1, MakeInstruction(BPF_RET + BPF_K, 1));

  // Sequence: "LD 0; JSET 0x100 ? RET 0 : RET 1", built twice.
  const CodeGen::Node first_test =
      MakeInstruction(BPF_JMP + BPF_JSET + BPF_K, 0x100, ret0, ret1);
  const CodeGen::Node first_seq =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 0, first_test);

  const CodeGen::Node second_test =
      MakeInstruction(BPF_JMP + BPF_JSET + BPF_K, 0x100, ret0, ret1);
  const CodeGen::Node second_seq =
      MakeInstruction(BPF_LD + BPF_W + BPF_ABS, 0, second_test);
  EXPECT_EQ(first_seq, second_seq);

  RunTest(first_seq);
}
// Builds a 260-instruction straight-line chain, then adds conditional jumps
// whose targets sit 250 to 259 nodes back from the most recently created node,
// compiling after each one.
// NOTE(review): presumably this exercises targets that are too far away for a
// conditional jump's offset field; confirm against CodeGen.
TEST_F(ProgramTest, FarBranches) {
  // chain[0] is "RET 0"; every later entry is an ALU add that continues into
  // the entry before it.
  std::vector<CodeGen::Node> chain;
  chain.push_back(MakeInstruction(BPF_RET + BPF_K, 0));
  for (size_t k = 1; k < 260; ++k) {
    const CodeGen::Node previous = chain.back();
    chain.push_back(MakeInstruction(BPF_ALU + BPF_ADD + BPF_K, k, previous));
  }

  for (size_t true_back = 250; true_back < 260; ++true_back) {
    for (size_t false_back = 250; false_back < 260; ++false_back) {
      // Distances are counted from the newest element, which changes on every
      // iteration because each new branch is appended to |chain| as well.
      const auto newest = chain.rbegin();
      const CodeGen::Node branch =
          MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 0, *(newest + true_back),
                          *(newest + false_back));
      chain.push_back(branch);
      RunTest(branch);
    }
  }
}
// Creates two branches to the same pair of distant targets and checks how far
// the returned node value advances each time.
// NOTE(review): the "+ 3" and "+ 1" expectations treat CodeGen::Node values as
// consecutive integers. Presumably the first branch needs two helper nodes
// plus the branch itself, and the second reuses those helpers; confirm
// against CodeGen.
TEST_F(ProgramTest, JumpReuse) {
  // Same 260-instruction chain as in FarBranches: "RET 0" followed by ALU
  // adds that each continue into the previously created entry.
  std::vector<CodeGen::Node> chain;
  chain.push_back(MakeInstruction(BPF_RET + BPF_K, 0));
  for (size_t k = 1; k < 260; ++k) {
    const CodeGen::Node previous = chain.back();
    chain.push_back(MakeInstruction(BPF_ALU + BPF_ADD + BPF_K, k, previous));
  }

  // First branch to the two oldest chain entries.
  const CodeGen::Node first_branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 0, chain[0], chain[1]);
  EXPECT_EQ(chain.back() + 3, first_branch);
  RunTest(first_branch);

  // Second branch to the same targets; it differs only in its constant.
  const CodeGen::Node second_branch =
      MakeInstruction(BPF_JMP + BPF_JEQ + BPF_K, 1, chain[0], chain[1]);
  EXPECT_EQ(first_branch + 1, second_branch);
  RunTest(second_branch);
}
}
}