/**
* BinaryEncodingVisitor.ets
*
* Binary 编码 Visitor - Protobuf wire format
*
* SwiftProtobuf 设计理念:
* "Visitor 负责所有格式的编码逻辑,消息类只需实现 traverse(),
* 无需关心编码细节。"
*
* 核心功能:
* 1. ⭐ 实现 Visitor 接口的所有 visit 方法
* 2. ⭐ 使用 Writer 编码字段为 Protobuf binary format
* 3. ⭐ 自动处理嵌套消息递归
* 4. ⭐ 内置递归深度保护(防止栈溢出攻击)
* 5. 支持所有 Protobuf 字段类型
*
* Version: 1.0.0
* ArkTS 2025 兼容
*/
import { Visitor } from './Visitor'
import { Writer } from './Writer'
import { Message } from './Message'
/**
 * Visitor that serializes a message into Protobuf binary wire format.
 *
 * A message only has to implement `traverse(visitor)`; each `visit*` method
 * here writes one field (a tag followed by its payload) to the underlying
 * Writer.
 *
 * Usage:
 * ```typescript
 * const visitor = new BinaryEncodingVisitor()
 * message.traverse(visitor)
 * const binary = visitor.finish()
 * ```
 *
 * Nesting protection:
 * - Every nested message is encoded by a child visitor whose depth is one
 *   greater than its parent's.
 * - Encoding throws an Error as soon as that depth would exceed 100, so a
 *   maliciously deep (or cyclic) message cannot exhaust the call stack.
 */
export class BinaryEncodingVisitor implements Visitor {
  /** Destination for every tag and payload written by this visitor. */
  private readonly writer: Writer
  /** Nesting level of the message this visitor encodes (0 for the root). */
  private recursionDepth: number = 0
  /** Deepest nesting level a child visitor is allowed to reach. */
  private readonly maxRecursionDepth: number = 100

  /**
   * @param writer Optional writer to append to; a fresh Writer is created
   *               when none is supplied.
   */
  constructor(writer?: Writer) {
    this.writer = writer ?? new Writer()
  }

  // ========== Scalar fields ==========

  /** Writes an int32 field. Wire type: 0 (varint). */
  visitInt32(value: number, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 0).int32(value)
  }

  /** Writes an int64 field. Wire type: 0 (varint). */
  visitInt64(value: bigint, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 0).int64(value)
  }

  /** Writes a uint32 field. Wire type: 0 (varint). */
  visitUint32(value: number, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 0).uint32(value)
  }

  /**
   * Writes a uint64 field. Wire type: 0 (varint).
   *
   * Uses the writer's unsigned 64-bit method, mirroring how visitUint32 uses
   * `uint32`, so the value is not routed through the signed int64 path.
   */
  visitUint64(value: bigint, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 0).uint64(value)
  }

  /** Writes a sint32 field (zigzag encoded). Wire type: 0 (varint). */
  visitSint32(value: number, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 0).sint32(value)
  }

  /** Writes a sint64 field (zigzag encoded). Wire type: 0 (varint). */
  visitSint64(value: bigint, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 0).sint64(value)
  }

  /** Writes a fixed32 field. Wire type: 5 (32-bit). */
  visitFixed32(value: number, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 5).fixed32(value)
  }

  /** Writes a fixed64 field. Wire type: 1 (64-bit). */
  visitFixed64(value: bigint, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 1).fixed64(value)
  }

  /** Writes an sfixed32 field. Wire type: 5 (32-bit). */
  visitSfixed32(value: number, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 5).sfixed32(value)
  }

  /** Writes an sfixed64 field. Wire type: 1 (64-bit). */
  visitSfixed64(value: bigint, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 1).sfixed64(value)
  }

  /** Writes a string field. Wire type: 2 (length-delimited). */
  visitString(value: string, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 2).string(value)
  }

  /** Writes a bytes field. Wire type: 2 (length-delimited). */
  visitBytes(value: Uint8Array, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 2).bytes(value)
  }

  /** Writes a bool field. Wire type: 0 (varint). */
  visitBool(value: boolean, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 0).bool(value)
  }

  /** Writes a float field. Wire type: 5 (32-bit). */
  visitFloat(value: number, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 5).float(value)
  }

  /** Writes a double field. Wire type: 1 (64-bit). */
  visitDouble(value: number, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 1).double(value)
  }

  // ========== Enum fields ==========

  /** Writes an enum field; the numeric value is written as an int32 varint. */
  visitEnum(value: number, fieldNumber: number): void {
    this.writer.tag(fieldNumber, 0).int32(value)
  }

  // ========== Nested messages ==========

  /**
   * Encodes a nested message into its own byte array.
   *
   * The depth limit is checked before any work is done and without touching
   * this visitor's own depth, so a rejected message leaves this visitor (and
   * its writer) exactly as it was.
   *
   * @param value Message to encode with a child visitor.
   * @returns The bytes produced by the child visitor's `finish()`.
   * @throws Error if the child's nesting depth would exceed the limit.
   */
  private encodeNestedMessage(value: Message): Uint8Array {
    const childDepth = this.recursionDepth + 1
    if (childDepth > this.maxRecursionDepth) {
      throw new Error(`Message nesting depth exceeds limit: ${this.maxRecursionDepth}`)
    }
    const child = new BinaryEncodingVisitor()
    // The child carries the depth forward so deeper levels keep counting.
    child.recursionDepth = childDepth
    value.traverse(child)
    return child.finish()
  }

  /**
   * Writes a nested message field. Wire type: 2 (length-delimited).
   *
   * The message is traversed by a child visitor and the resulting bytes are
   * written as this field's payload.
   *
   * @param value Nested message object.
   * @param fieldNumber Proto field number.
   * @throws Error if the nesting depth exceeds the limit.
   */
  visitMessage(value: Message, fieldNumber: number): void {
    const encoded = this.encodeNestedMessage(value)
    this.writer.tag(fieldNumber, 2).bytes(encoded)
  }

  // ========== Repeated fields ==========
  // Every element is written as its own tag/value pair (unpacked form).

  /** Writes each element of a repeated int32 field. */
  visitRepeatedInt32(value: number[], fieldNumber: number): void {
    for (const item of value) {
      this.visitInt32(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated int64 field. */
  visitRepeatedInt64(value: bigint[], fieldNumber: number): void {
    for (const item of value) {
      this.visitInt64(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated uint32 field. */
  visitRepeatedUint32(value: number[], fieldNumber: number): void {
    for (const item of value) {
      this.visitUint32(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated uint64 field. */
  visitRepeatedUint64(value: bigint[], fieldNumber: number): void {
    for (const item of value) {
      this.visitUint64(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated sint32 field. */
  visitRepeatedSint32(value: number[], fieldNumber: number): void {
    for (const item of value) {
      this.visitSint32(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated sint64 field. */
  visitRepeatedSint64(value: bigint[], fieldNumber: number): void {
    for (const item of value) {
      this.visitSint64(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated fixed32 field. */
  visitRepeatedFixed32(value: number[], fieldNumber: number): void {
    for (const item of value) {
      this.visitFixed32(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated fixed64 field. */
  visitRepeatedFixed64(value: bigint[], fieldNumber: number): void {
    for (const item of value) {
      this.visitFixed64(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated sfixed32 field. */
  visitRepeatedSfixed32(value: number[], fieldNumber: number): void {
    for (const item of value) {
      this.visitSfixed32(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated sfixed64 field. */
  visitRepeatedSfixed64(value: bigint[], fieldNumber: number): void {
    for (const item of value) {
      this.visitSfixed64(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated string field. */
  visitRepeatedString(value: string[], fieldNumber: number): void {
    for (const item of value) {
      this.visitString(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated bytes field. */
  visitRepeatedBytes(value: Uint8Array[], fieldNumber: number): void {
    for (const item of value) {
      this.visitBytes(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated bool field. */
  visitRepeatedBool(value: boolean[], fieldNumber: number): void {
    for (const item of value) {
      this.visitBool(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated float field. */
  visitRepeatedFloat(value: number[], fieldNumber: number): void {
    for (const item of value) {
      this.visitFloat(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated double field. */
  visitRepeatedDouble(value: number[], fieldNumber: number): void {
    for (const item of value) {
      this.visitDouble(item, fieldNumber)
    }
  }

  /** Writes each element of a repeated enum field. */
  visitRepeatedEnum(value: number[], fieldNumber: number): void {
    for (const item of value) {
      this.visitEnum(item, fieldNumber)
    }
  }

  /**
   * Writes each element of a repeated message field via visitMessage.
   *
   * @throws Error wrapping any failure, prefixed with the field number so the
   *         offending field can be identified.
   */
  visitRepeatedMessage(value: Message[], fieldNumber: number): void {
    for (const item of value) {
      try {
        this.visitMessage(item, fieldNumber)
      } catch (e) {
        throw new Error(`Failed to visit repeated message at field ${fieldNumber}: ${e}`)
      }
    }
  }

  // ========== Map fields ==========
  // A map is written as one length-delimited entry per key/value pair, where
  // the entry holds the key as field 1 and the value as field 2:
  //
  //   repeated Entry {
  //     K key = 1;
  //     V value = 2;
  //   }
  //
  // NOTE(review): this relies on Writer.fork() returning an independent writer
  // and Writer.ldelim(w) appending w's content length-prefixed — confirm
  // against the Writer implementation.

  /** Writes a map<string, int32> field. */
  visitMapStringInt32(value: Map<string, number>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 2).string(k) // key
      entryWriter.tag(2, 0).int32(v) // value
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /** Writes a map<string, int64> field. */
  visitMapStringInt64(value: Map<string, bigint>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 2).string(k)
      entryWriter.tag(2, 0).int64(v)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /** Writes a map<string, string> field. */
  visitMapStringString(value: Map<string, string>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 2).string(k)
      entryWriter.tag(2, 2).string(v)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /** Writes a map<string, bytes> field. */
  visitMapStringBytes(value: Map<string, Uint8Array>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 2).string(k)
      entryWriter.tag(2, 2).bytes(v)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /** Writes a map<string, bool> field. */
  visitMapStringBool(value: Map<string, boolean>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 2).string(k)
      entryWriter.tag(2, 0).bool(v)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /**
   * Writes a map<string, message> field.
   *
   * @throws Error if a value's nesting depth exceeds the limit.
   */
  visitMapStringMessage(value: Map<string, Message>, fieldNumber: number): void {
    value.forEach((v, k) => {
      // Encode the value first so a depth violation aborts before the entry
      // writer is forked.
      const encoded = this.encodeNestedMessage(v)
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 2).string(k)
      entryWriter.tag(2, 2).bytes(encoded)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /** Writes a map<int32, int32> field. */
  visitMapInt32Int32(value: Map<number, number>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 0).int32(k)
      entryWriter.tag(2, 0).int32(v)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /** Writes a map<int32, string> field. */
  visitMapInt32String(value: Map<number, string>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 0).int32(k)
      entryWriter.tag(2, 2).string(v)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /**
   * Writes a map<int32, message> field.
   *
   * @throws Error if a value's nesting depth exceeds the limit.
   */
  visitMapInt32Message(value: Map<number, Message>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const encoded = this.encodeNestedMessage(v)
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 0).int32(k)
      entryWriter.tag(2, 2).bytes(encoded)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /** Writes a map<int64, int64> field. */
  visitMapInt64Int64(value: Map<bigint, bigint>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 0).int64(k)
      entryWriter.tag(2, 0).int64(v)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /** Writes a map<int64, string> field. */
  visitMapInt64String(value: Map<bigint, string>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 0).int64(k)
      entryWriter.tag(2, 2).string(v)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  /**
   * Writes a map<int64, message> field.
   *
   * @throws Error if a value's nesting depth exceeds the limit.
   */
  visitMapInt64Message(value: Map<bigint, Message>, fieldNumber: number): void {
    value.forEach((v, k) => {
      const encoded = this.encodeNestedMessage(v)
      const entryWriter = this.writer.fork()
      entryWriter.tag(1, 0).int64(k)
      entryWriter.tag(2, 2).bytes(encoded)
      this.writer.tag(fieldNumber, 2).ldelim(entryWriter)
    })
  }

  // ========== Finish ==========

  /**
   * Completes encoding.
   *
   * @returns Whatever the underlying writer's `finish()` yields: the encoded
   *          Protobuf wire-format bytes.
   */
  finish(): Uint8Array {
    return this.writer.finish()
  }
}