RFC: Transformer 重复层代表层优化
Status: Approved Author(s): yaohan404 Created: 2026-05-21 Updated: 2026-05-21
1. 概述
本文描述 TensorCast 针对 Transformer 重复层的代表层优化方案。该方案面向大语言模型中高度重复的 decoder layer 结构:模型仍保留完整层数和完整建模结果,但模型构建阶段只对结构等价的一组层中的代表层执行 patch_moe、quantize_model、shard_model 等 host-side transform,从而降低建模启动成本。
代码中已有一个历史开关名为 enable_repetition。为了避免依赖内部讨论术语,本文将该能力称为“重复层复用建模”:当多个 Transformer layer 具有相同子模块结构时,TensorCast 只真实执行第一层,并用 copy region 表达后续等价层的执行效果。本文的优化是在这个机制之上,将 host-side transform 也收敛到代表层规模。
2. 重复层复用建模机制
Transformer decoder-only 模型通常由大量结构相同的 layer 组成。例如 dense 模型可能有 64 个结构相同的 decoder layer,MoE 模型可能包含少数 dense layer 加大量结构相同的 MoE layer。
TensorCast 的重复层复用建模机制基于两个 wrapper:
RegionMarkerWrapper:包裹一组结构等价层中的第一层,也就是代表层。该层保留真实计算,并在 runtime 中标记一个可复用区域。CopyLayerWrapper:包裹后续结构等价层。forward 时通过 copy region 表达“复用代表层执行区域”的效果,而不是重复执行完整 layer 计算。
该机制的目标不是改变模型结构,而是减少性能建模过程中重复 layer 带来的执行和编译成本。外部仍应看到完整 ModuleList 层数,analytic model 也仍应统计完整层数对应的算子和权重。
优化前,该机制已经能降低 runtime/compile 中重复层的真实执行成本,但模型构建阶段仍有一个问题:CopyLayerWrapper 仍持有原始 _inner layer,默认 named_modules()、named_parameters() 会递归进入内部模块。因此后续转换逻辑仍会处理所有重复层内部的 MoE experts 和 Linear modules。
对于 DeepSeek-V3.2 这类 MoE 模型,这会放大成明显的启动耗时:
quantize_model会量化所有重复 MoE layer 内部的 experts/linearpatch_moe会遍历所有重复 MoE layer 的 expert 结构shard_model会继续处理重复层中的参数和模块
本文方案解决的是这个 host-side transform 放大问题。
3. 背景与动机
典型 DeepSeek-V3.2 text_generate 命令:
python -m cli.inference.text_generate deepseek-ai/DeepSeek-V3.2 `
--num-queries 16 `
--query-length 1 `
--context-length 4096 `
--device ATLAS_800_A3_752T_128G_DIE `
--quantize-linear-action FP8 `
--num-devices 16 `
--tp-size 1 `
--ep-size 16 `
--compile
优化前的主要问题:
- 重复层复用建模已经降低 compile/runtime 重复层执行,但没有降低 host-side transform 的重复层遍历。
quantize_model会处理所有重复层内部 linear,DeepSeek-V3.2 中会放大到数万次QuantLinearBase构建和量化。patch_moe会遍历重复 MoE layer 及 expert 结构,产生秒级额外耗时。--num-hidden-layers-override 2可以显著加速,但会改变模型结构和完整建模结果,不能作为正式优化方案。
4. 目标与非目标
目标:
- 保留完整
ModuleList层数,不改变模型外部结构。 - 保留原有 copy region runtime 语义,重复层仍由 copy region 表达。
- 后续 host-side transform 只处理代表层内部模块。
- 保持 full model weight accounting,不因隐藏 repeated layer 参数而低估模型权重。
- 验证 op table、analytic execution time、model weight、KV cache 等建模结果一致。
非目标:
- 不重写 torch.compile 或 copy region runtime 逻辑。
- 不优化 Python import/startup。
- 不重构 throughput_optimizer 搜索框架。
- 不引入 shape-only quantization 或 MoE proxy 新抽象。
5. 方案设计
5.1 代表层转换
maybe_reuse_layers() 根据 layer 子模块结构生成 key:
- 第一个结构 key 对应的 layer 包装为
RegionMarkerWrapper - 后续相同 key 的 layer 包装为
CopyLayerWrapper
本方案保持完整层列表长度,但让 CopyLayerWrapper 在模块遍历中表现为叶子节点:
CopyLayerWrapper.named_modules()只返回自身CopyLayerWrapper.modules()只返回自身CopyLayerWrapper.named_parameters()/parameters()返回空CopyLayerWrapper.named_buffers()/buffers()返回空_inner仍保留,用于 forward 语义、元信息兼容和调试
这样后续 host-side transform 会跳过 repeated copy layer 的内部结构,只处理代表层。
5.2 repeat_count 与权重统计
因为 copy layers 隐藏了内部参数,直接用 named_parameters() 统计会低估模型权重。方案在 RegionMarkerWrapper 上引入 repeat_count:
- 代表层的
repeat_count初始为 1 - 每遇到一个结构相同的 repeated layer,递增对应 representative wrapper 的
repeat_count TransformerModel.get_weight_size_nested()额外调用get_represented_extra_weight_size()- 对每个
RegionMarkerWrapper追加(repeat_count - 1) * representative_inner_weight_size
这样 host transform 只处理代表层,但 full model weight size 仍按完整层数统计。
5.3 影响范围
代码变更范围:
tensor_cast/layers/internal.pyRegionMarkerWrapper增加repeat_countCopyLayerWrapper隐藏内部 module/parameter/buffer traversal
tensor_cast/transformers/transformations.pymaybe_reuse_layers()保留完整层数,同时维护结构 key 到 representative wrapper 的映射
tensor_cast/transformers/model.py- 权重统计补充 repeated layer represented weight
tests/test_tensor_cast/test_repetition.py- 增加 representative layer 数量和
repeat_count验证
- 增加 representative layer 数量和
6. 测试与验证结果
6.1 单元测试
已验证:
python -m pytest -q tests\test_tensor_cast\test_repetition.py
结果:
4 passed
补充验证:
python -m pytest -q tests\test_tensor_cast\test_text_generate.py -k "num_hidden_layers_override or disable_repetition"
结果:
2 passed, 95 deselected
6.2 DeepSeek-V3.2 text_generate no-compile
命令:
python -m cli.inference.text_generate deepseek-ai/DeepSeek-V3.2 `
--num-queries 16 `
--query-length 1 `
--context-length 4096 `
--device ATLAS_800_A3_752T_128G_DIE `
--quantize-linear-action FP8 `
--num-devices 16 `
--tp-size 1 `
--ep-size 16
| 指标 | 优化前 | 优化后 | 变化 |
|---|---|---|---|
| profiling total | 23.638s | 5.328s | 4.44x |
model_runner_init |
17.989s | 0.749s | 24.0x |
patch_moe |
3.309s | 0.046s | 72.0x |
quantize_model |
13.741s | 0.239s | 57.5x |
shard_model |
0.481s | 0.010s | 48.1x |
QuantLinearBase.__init__ calls |
45,215 | 790 | 57.2x fewer |
QuantLinearBase.quantize_weight calls |
45,215 | 790 | 57.2x fewer |
TransformerModel._replace_module calls |
3,154 | 62 | 50.9x fewer |
| analytic execution time | 0.061924s | 0.061924s | unchanged |
| model weight size | 56.663GB | 56.663GB | unchanged |
| KV cache | 0.276GB | 0.276GB | unchanged |
6.3 DeepSeek-V3.2 text_generate compile
命令:
python -m cli.inference.text_generate deepseek-ai/DeepSeek-V3.2 `
--num-queries 16 `
--query-length 1 `
--context-length 4096 `
--device ATLAS_800_A3_752T_128G_DIE `
--quantize-linear-action FP8 `
--num-devices 16 `
--tp-size 1 `
--ep-size 16 `
--compile
| 指标 | 优化前 | 优化后 | 变化 |
|---|---|---|---|
| CLI wall time | 35.590s | 15.854s | 2.25x |
Model compilation and execution time |
8.365s | 7.146s | 1.17x |
| analytic execution time | 0.045181s | 0.045181s | unchanged |
| model weight size | 56.663GB | 56.663GB | unchanged |
| KV cache | 0.276GB | 0.276GB | unchanged |
tensor_cast.cat.default calls |
58 | 58 | unchanged |
aten.cat.default calls |
58 | 58 | unchanged |
结论:
- compile 场景的主要收益来自 compile 前 host-side transform/init,而不是 compile graph 本身。
- 优化前已有 copy region runtime 机制,compile/runtime 已经只执行 representative region。
- 本方案保持 op table 中 repeated copy ops 的统计一致性。
6.4 Dense Qwen case
命令:
python -m cli.inference.text_generate Qwen/Qwen3-32B `
--device TEST_DEVICE `
--tp-size 2 `
--num-devices 2 `
--num-queries 2 `
--query-length 1000 `
--compile `
--quantize-linear-action FP8
| 指标 | 优化前 | 优化后 | 变化 |
|---|---|---|---|
| phase total | 12.169s | 12.121s | 基本持平 |
model_runner_init |
0.606s | 0.374s | -0.232s |
quantize_model |
0.175s | 0.004s | 43.8x |
shard_model |
0.064s | 0.004s | 16.0x |
QuantLinearBase.__init__ calls |
448 | 7 | 64.0x fewer |
| analytic execution time | 0.259494s | 0.259494s | unchanged |
| model weight size | 16.725504GB | 16.725504GB | unchanged |
结论:
- 代表层优化对 dense 模型也生效,但 dense 模型 host-side transform 原始绝对耗时较小。
- 端到端耗时主要由 import/startup 和 compile/run 构成,因此用户体感不明显。
6.5 throughput_optimizer
DeepSeek-V3.2 aggregation case:
python -m cli.inference.throughput_optimizer deepseek-ai/DeepSeek-V3.2 `
--input-length 4096 `
--output-length 128 `
--device ATLAS_800_A3_752T_128G_DIE `
--quantize-linear-action FP8 `
--num-devices 16 `
--tp-sizes 1 `
--ep-sizes 16 `
--moe-dp-sizes 1 `
--batch-range 1 16 `
--tpot-limits 200 `
--max-prefill-tokens 8192 `
--reserved-memory-gb 0 `
--jobs 1
| 指标 | 优化前 | 优化后 |
|---|---|---|
| wall time | 43.727s | 21.644s |
| best throughput | 80.64 token/s | 80.64 token/s |
| TTFT | 8735.89ms | 8735.89ms |
| TPOT | 130.17ms | 130.17ms |
DeepSeek-V3.2 disagg single-config case:
python -m cli.inference.throughput_optimizer deepseek-ai/DeepSeek-V3.2 `
--device ATLAS_800_A3_752T_128G_DIE `
--num-devices 16 `
--input-length 3500 `
--output-length 1500 `
--disagg `
--ttft-limits 2000 `
--tpot-limits 100 `
--compile `
--tp-sizes 1 `
--jobs 1
| 指标 | 优化前 | 优化后 |
|---|---|---|
All experiments completed |
117.11s | 63.06s |
| outer wall time | 122.175s | 68.128s |
完整 disagg 默认搜索中,收益会被 5 个 TP 配置、Prefill/Decode 两阶段、jobs=8 多进程并发和 batch search/compile 成本稀释。该部分后续优化应在 throughput_optimizer 搜索框架中继续推进。
7. 风险与边界
7.1 权重统计风险
CopyLayerWrapper 隐藏内部参数后,必须通过 RegionMarkerWrapper.repeat_count 补偿 represented weight。测试中已覆盖 weight size 一致性,但后续如果引入新的 wrapper 或非标准 layer list,需要确认是否仍能正确统计。
7.2 模块遍历语义变化
本方案改变了 CopyLayerWrapper 的 traversal 行为。预期收益正来自该变化,但如果未来有逻辑需要检查 repeated copy layer 的 _inner 结构,需要显式访问 _inner,不能依赖 named_modules()。
7.3 非完全重复结构
当前代表层检测仍基于子模块结构 key,不处理参数值差异。该行为与原有重复层复用机制一致,本方案不扩大适用范围。
7.4 compile 收益边界
本方案不承诺 compile graph 按层数线性下降。原有 copy region runtime 已经优化 compile/runtime;本文方案主要降低 compile 前模型构建和转换成本。
8. 后续工作
本 RFC 不包含但建议后续单独推进:
- throughput_optimizer 优化:
- 合并 Prefill/Decode 同配置任务
- 复用 ProcessPoolExecutor
- 减少 batch search forward/compile 次数
- shape-only quantization:
- 本方案后
quantize_model已不再是主瓶颈,但仍可作为小模型/多 case benchmark 噪声优化。
- 本方案后
- import/startup 优化:
- lazy builtin registration 已实验,DeepSeek-V3.2 compile 主 case 重复测试收益约 0.276s,暂不作为当前主要优化方向。
9. 验收标准
代码合入前至少满足:
tests/test_tensor_cast/test_repetition.py通过tests/test_tensor_cast/test_text_generate.py -k "num_hidden_layers_override or disable_repetition"通过- DeepSeek-V3.2 text_generate compile/no-compile 的 analytic execution time、weight、KV cache 一致
- DeepSeek-V3.2 compile op table 中
tensor_cast.cat.default和aten.cat.defaultcall count 保持 58 - throughput_optimizer best result 不变