import os
import re
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
from collections import defaultdict
# --- Megatron (training) dump layout ---------------------------------------
MEGA_DP_SIZE = 4
MEGA_TP_SIZE = 4
MEGA_CP_SIZE = 2
# Rank groups read together by load_mega_logits (paired one-to-one with MEGA_IPS);
# each group is MEGA_TP_SIZE consecutive ranks out of 16.
MEGA_LD_RANKS = [list(range(first, first + MEGA_TP_SIZE)) for first in range(0, 16, MEGA_TP_SIZE)]
MEGA_LAYER_NAME = "model_0.module.module.output_layer_forward"
MEGA_IPS = ["0.0.0.0"] * 4

# --- vLLM (rollout) dump layout ---------------------------------------------
VLLM_DP_SIZE = 2
VLLM_LD_RANKS = [[0], [4]]
VLLM_IPS = ["0.0.0.0"] * 2
VLLM_LAYER_NAME = "logits_processor_forward"

# --- Analysis settings ------------------------------------------------------
STEP = [3]  # training steps analysed by the __main__ entry point
INTERACTION_TIME = 1  # max vLLM sub-sequences placed per trajectory
VOCAB_SIZE = 152_064
K_MATCH_THRES = 0.3  # min fraction of k-gram votes for a candidate match
def draw_train_infer_max_token_logp(vllm_probs, mega_probs, current_step, mega_id):
    """Plot per-token log p curves for vLLM and Megatron, each raw and EWM-smoothed.

    Args:
        vllm_probs: 2-D array; each row is one vLLM curve.
        mega_probs: 2-D array; each row is one Megatron curve.
        current_step: training step, used only in the output file name.
        mega_id: Megatron trajectory id, used only in the output file name.

    Side effects:
        Saves ``logp_step-{current_step}_traj-{mega_id}.png`` (300 dpi) to the current
        working directory and closes the figure, even if plotting fails.
    """
    import matplotlib.colors as mc

    alpha = 0.1  # EWM smoothing factor for the dashed curves
    colors_a = plt.cm.tab20c.colors[:4]
    colors_b = plt.cm.tab20c.colors[4:8]

    def lighten_color(color, amount=0.3):
        """Blend an RGB(A) color toward white; ``amount`` in [0, 1], larger is lighter."""
        c = np.array(mc.to_rgb(color))
        return np.clip(c + (1 - c) * amount, 0, 1)

    def plot_rows(rows, palette, prefix):
        """Draw each row of ``rows`` as a solid raw curve plus a dashed smoothed curve."""
        for i in range(rows.shape[0]):
            orig = rows[i]
            smoothed = pd.Series(orig).ewm(alpha=alpha).mean().values
            # Cycle the palette so more rows than colors no longer raise IndexError.
            c = palette[i % len(palette)]
            plt.plot(orig, color=c, linewidth=1.5, label=f'{prefix}-{i + 1}')
            plt.plot(
                smoothed,
                color=lighten_color(c, amount=0.4),
                linestyle='--',
                linewidth=1.5,
                label=f'{prefix}-{i + 1} (smooth)',
            )

    plt.figure(figsize=(12, 7))
    try:
        plot_rows(vllm_probs, colors_a, 'vLLM')
        plot_rows(mega_probs, colors_b, 'Mega')
        plt.title('Log p of vLLM and Mega')
        plt.xlabel('Token')
        plt.ylabel('log p')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, linestyle='--', alpha=0.5)
        plt.tight_layout()
        plt.savefig(f'logp_step-{current_step}_traj-{mega_id}.png', dpi=300, bbox_inches='tight')
    finally:
        plt.close()
def draw_train_infer_logits_diff(vllm_probs, mega_probs, current_step):
    """Scatter inference (x) vs. training (y) probabilities, colored by their absolute gap.

    Args:
        vllm_probs: 1-D array of inference-side probabilities.
        mega_probs: 1-D array of training-side probabilities, aligned with ``vllm_probs``.
        current_step: training step, used only in the output file name.

    Returns:
        Pearson correlation coefficient between the two arrays.

    Side effects:
        Saves ``prob_difference_step-{current_step}.png`` (300 dpi) to the working directory.
    """
    gap = np.abs(vllm_probs - mega_probs)
    pearson_r = np.corrcoef(vllm_probs, mega_probs)[0, 1]

    plt.figure(figsize=(8, 6))
    ax = plt.gca()
    points = plt.scatter(vllm_probs, mega_probs, c=gap, cmap='viridis_r', alpha=0.7, edgecolors='none')
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    # Reference diagonal: points on it have identical train/infer probability.
    ax.plot([0, 1], [0, 1], color='red', linestyle='--', linewidth=1, label='y = x')
    ax.legend()

    label_box = dict(boxstyle='round,pad=0.3', facecolor='white', alpha=0.8)
    ax.text(0.75, 0.05, f'Pearson r = {pearson_r:.3f}', transform=ax.transAxes, fontsize=12, bbox=label_box)

    colorbar = plt.colorbar(points, ticks=[0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
    colorbar.set_label('Absolute Difference (|Infer - Train|)', rotation=270, labelpad=20)

    ax.set_xlabel('Inference Probability')
    ax.set_ylabel('Training Probability')
    ax.set_title('Token-wise Probability Comparison')
    ax.grid(True, linestyle='--', alpha=0.5)
    plt.tight_layout()
    plt.savefig(f'prob_difference_step-{current_step}.png', dpi=300, bbox_inches='tight')
    plt.close()
    return pearson_r
def lcs(s1, s2):
    """Align training and inference token sequences with an LCS dynamic program.

    s1: predicted tokens in training
    s2: predicted tokens in inference

    Motivation: without tool use, every token of s2 should be able to match tokens of s1
    contiguously, and even after a few mismatches the following tokens are still likely
    to match. With tool use, several consecutive tokens of s1 may fail to match s2, after
    which matching resumes contiguously. The DP below is the standard LCS recurrence;
    on ties it prefers the diagonal, then up, then left.

    Returns:
        ``(length, path)``. ``length`` is the LCS length. ``path`` holds 1-based ``(i, j)``
        cells from the backtrack, ordered from the end of both sequences backwards.
        Straight-up moves (same ``j``) collapse into one entry, so each visited column of
        ``s2`` appears at most once, and the terminal cell on row/column 0 is dropped.
        An entry pairs ``s1[i - 1]`` with ``s2[j - 1]``; the two tokens need not be equal.
    """
    rows, cols = len(s1), len(s2)
    score = [[0] * (cols + 1) for _ in range(rows + 1)]
    parent = [[(0, 0)] * (cols + 1) for _ in range(rows + 1)]

    for i in range(1, rows + 1):
        token = s1[i - 1]
        for j in range(1, cols + 1):
            if token == s2[j - 1]:
                score[i][j] = score[i - 1][j - 1] + 1
                parent[i][j] = (i - 1, j - 1)
                continue
            # No match: take the best neighbour; later checks win ties (diag > up > left).
            best, src = score[i][j - 1], (i, j - 1)
            if score[i - 1][j] >= best:
                best, src = score[i - 1][j], (i - 1, j)
            if score[i - 1][j - 1] >= best:
                best, src = score[i - 1][j - 1], (i - 1, j - 1)
            score[i][j], parent[i][j] = best, src

    i, j = rows, cols
    path = [(i, j)]
    while i > 0 and j > 0:
        i, j = parent[i][j]
        if path[-1][1] == j:
            path[-1] = (i, j)  # moved straight up: keep only the newest cell of this column
        else:
            path.append((i, j))
    return score[rows][cols], path[:-1]
def lcs_match(mega_pred, candidate_reqs, merged_vllm_index_dict, left_traj):
    """Pick the candidate vLLM request whose predicted tokens best align with ``mega_pred``.

    Args:
        mega_pred: tensor of Megatron predicted tokens.
        candidate_reqs: iterable of request ids to try.
        merged_vllm_index_dict: ``req_id -> index tensor``; column 0 of the last axis is used.
        left_traj: requests that win ties against an earlier candidate with the same score.

    Returns:
        ``(req_id, match_num, mega_index, vllm_index)``. ``req_id`` is ``""`` and
        ``match_num`` is -1 when there are no candidates. The index arrays are 0-based
        int64 arrays in forward order (empty, but still int64, when there is no path).
    """
    # Invariant across candidates: convert once instead of on every lcs call.
    mega_tokens = mega_pred.cpu().numpy()
    best_num, best_req, best_path = -1, "", []
    for req_id in candidate_reqs:
        vllm_tokens = merged_vllm_index_dict[req_id][..., 0].cpu().numpy()
        match_num, path = lcs(mega_tokens, vllm_tokens)
        if match_num > best_num or (match_num == best_num and req_id in left_traj):
            best_num, best_req, best_path = match_num, req_id, path

    # lcs returns 1-based cells from the end backwards; flip and shift to 0-based.
    ordered = best_path[::-1]
    # Explicit int64 so an empty result is still usable as an index array.
    mega_index = np.array([i - 1 for i, _ in ordered], dtype=np.int64)
    vllm_index = np.array([j - 1 for _, j in ordered], dtype=np.int64)
    return best_req, best_num, mega_index, vllm_index
def single_trajectory_match(divided_vllm_index, mega_pred, interaction_time=None):
    """Place each vLLM sub-sequence inside a Megatron prediction by sliding-window matching.

    For each sub-sequence, every start offset at or after the end of the previous
    placement is tried. The offset with the most position-wise equal tokens wins, with the
    earliest offset winning ties. Processing stops at the first empty sub-sequence.

    Args:
        divided_vllm_index: sequence of index tensors; column 0 of the last axis is compared.
        mega_pred: tensor of Megatron predicted tokens (flattened to 1-D here).
        interaction_time: max number of sub-sequences to place; ``None`` uses INTERACTION_TIME.

    Returns:
        ``(match_num, st_list, vllm_total_len)``:
        - total matched tokens;
        - the chosen start offset per placed sub-sequence;
        - the total length of the placed sub-sequences.
    """
    limit = INTERACTION_TIME if interaction_time is None else interaction_time
    # A squeezed length-1 tensor is 0-d; flatten so shape[0] and slicing always work.
    mega_pred = mega_pred.reshape(-1)
    mega_len = mega_pred.shape[0]
    next_st = 0
    st_list = []
    current_match_num = 0
    vllm_total_len = 0
    # Slicing (rather than indexing by position) tolerates fewer sub-sequences than `limit`.
    for sub_index in divided_vllm_index[:limit]:
        if sub_index.nelement() == 0:
            break
        vllm_pred = sub_index[..., 0]
        sub_len = vllm_pred.shape[0]
        best_num = -1
        for st in range(next_st, mega_len):
            valid_length = min(sub_len, mega_len - st)
            match_num = (vllm_pred[:valid_length] == mega_pred[st : st + valid_length]).int().sum().item()
            if match_num > best_num:
                best_num = match_num
                next_st = st
        # No window was evaluated (offset past the end): count 0 matches, never -1.
        current_match_num += max(best_num, 0)
        st_list.append(next_st)
        next_st += sub_len
        vllm_total_len += sub_len
    return current_match_num, st_list, vllm_total_len
class MegaSegTree:
    """Segment tree over positions ``0 .. n-1`` holding a 0/1 color per position.

    Supports two operations, both using lazy propagation:
    - assign a color to an inclusive range;
    - ask whether an inclusive range contains any 1.

    greedy_match uses one tree per Megatron trajectory to mark positions already claimed.
    An empty tree (``n <= 0``) ignores updates and reports no 1s.
    """

    def __init__(self, n):
        self.n = n
        # has_one[node]: whether node's segment contains a 1 (accurate once pending lazy values are pushed).
        self.has_one = [False] * (4 * n)
        # lazy[node]: color (0/1) pending assignment to node's whole segment, or None.
        self.lazy = [None] * (4 * n)

    def _push(self, node, start, end):
        """Apply node's pending color to itself and hand it down to its children."""
        if self.lazy[node] is not None:
            color = self.lazy[node]
            self.has_one[node] = color == 1
            if start != end:
                self.lazy[node * 2] = color
                self.lazy[node * 2 + 1] = color
            self.lazy[node] = None

    def _update(self, node, start, end, l, r, color):
        self._push(node, start, end)
        if r < start or end < l:
            return
        if l <= start and end <= r:
            self.lazy[node] = color
            self._push(node, start, end)
            return
        mid = (start + end) // 2
        self._update(node * 2, start, mid, l, r, color)
        self._update(node * 2 + 1, mid + 1, end, l, r, color)
        self._push(node * 2, start, mid)
        self._push(node * 2 + 1, mid + 1, end)
        self.has_one[node] = self.has_one[node * 2] or self.has_one[node * 2 + 1]

    def update(self, l, r, color):
        """Color the inclusive range [l, r] with ``color`` (0 or 1); no-op on an empty tree."""
        if self.n <= 0:
            return
        self._update(1, 0, self.n - 1, l, r, color)

    def _query(self, node, start, end, l, r):
        if r < start or end < l:
            return False
        self._push(node, start, end)
        if l <= start and end <= r:
            return self.has_one[node]
        mid = (start + end) // 2
        left_has = self._query(node * 2, start, mid, l, r)
        right_has = self._query(node * 2 + 1, mid + 1, end, l, r)
        return left_has or right_has

    def query_has_one(self, l, r):
        """Return True if any position in the inclusive range [l, r] is 1 (False on an empty tree)."""
        if self.n <= 0:
            return False
        return self._query(1, 0, self.n - 1, l, r)
def greedy_match(edges, divided_vllm_index_dict, merged_mega_index_list):
    """Greedily pair vLLM trajectories with non-overlapping spans of Megatron trajectories.

    Edges are consumed in the given order (the caller sorts them best-first). A placement
    is accepted when the trajectory is still unmatched, the span is non-empty, and the
    Megatron span it covers was not claimed by an earlier match. Each trajectory is
    matched at most once.

    Args:
        edges: iterable of ``(weight, traj_id, idx, match_num, st_list)``; ``weight`` is not
            read here. ``st_list[i]`` is the Megatron start offset of vLLM sub-sequence ``i``.
        divided_vllm_index_dict: ``traj_id -> list`` of vLLM sub-sequence tensors; only
            ``.shape[0]`` is read.
        merged_mega_index_list: one tensor per Megatron trajectory; only ``.shape[0]`` is
            read, so any per-token tensor list works.

    Returns:
        List of dicts with keys ``mega_id``, ``req_id``, ``match_num``, ``mega_index`` and
        ``vllm_index``. The two index arrays are equal-length, non-empty int64 arrays of
        0-based positions.
    """
    mega_lens = [mega.shape[0] for mega in merged_mega_index_list]
    claimed = [MegaSegTree(n) for n in mega_lens]
    total_trajs = len(divided_vllm_index_dict)
    matched_traj = set()
    match_map = []
    for _, traj_id, idx, match_num, st_list in edges:
        if traj_id in matched_traj:
            continue
        # Position of the current sub-sequence inside the merged vLLM sequence; assumes the
        # merged sequence is the in-order concatenation of the sub-sequences (TODO confirm).
        vllm_offset = 0
        for seg_id, st in enumerate(st_list):
            sub_len = divided_vllm_index_dict[traj_id][seg_id].shape[0]
            end = min(st + sub_len, mega_lens[idx])  # exclusive end of the Megatron span
            seg_offset = vllm_offset
            vllm_offset += sub_len
            # Skip spans past the end of the Megatron sequence and spans overlapping an
            # earlier match. Query exactly the inclusive range [st, end - 1] that is claimed below.
            if end <= st or claimed[idx].query_has_one(st, end - 1):
                continue
            mega_index = np.arange(st, end, dtype=np.int64)
            vllm_index = np.arange(seg_offset, seg_offset + (end - st), dtype=np.int64)
            match_map.append(
                {
                    "mega_id": idx,
                    "req_id": traj_id,
                    "match_num": match_num,
                    "mega_index": mega_index,
                    "vllm_index": vllm_index,
                }
            )
            matched_traj.add(traj_id)
            claimed[idx].update(st, end - 1, 1)
            if len(match_map) == total_trajs:
                # Every trajectory is matched; no remaining edge can add anything.
                return match_map
            break
    return match_map
def k_gram_match(mega_index_list, vllm_index_dict, k=5, base=256, mod=2**61 - 1, thres=None):
    """Quickly match megatron prediction and vllm prediction, provide approximated match map.

    Every length-``k`` window of each Megatron token sequence (column 0 of the last axis)
    is indexed by a rolling polynomial hash. Each vLLM sequence's windows then vote for
    the Megatron sequences containing the same hash. A Megatron sequence becomes a
    candidate when its votes exceed ``thres`` times the number of vLLM windows.
    A sequence shorter than ``k`` is hashed as a single window of its full length.
    Empty sequences are ignored.

    Args:
        mega_index_list: list of Megatron index tensors.
        vllm_index_dict: ``traj_id -> vLLM index tensor``.
        k, base, mod: window length and hash parameters.
        thres: vote-fraction threshold; ``None`` uses K_MATCH_THRES.

    Returns:
        ``traj_id -> list`` of Megatron indices, ordered by descending vote count.
    """
    if thres is None:
        thres = K_MATCH_THRES
    power = pow(base, k - 1, mod)

    def to_tokens(index):
        # Plain Python ints keep the hash arithmetic exact; NumPy int64 could wrap silently.
        return index[..., 0].reshape(-1).tolist()

    def window_hashes(tokens):
        """Yield the rolling hash of each length-k window (one hash if shorter than k)."""
        h = 0
        for tok in tokens[:k]:
            h = (h * base + tok) % mod
        yield h
        for i in range(1, len(tokens) - k + 1):
            h = (h - tokens[i - 1] * power) % mod
            h = (h * base + tokens[i + k - 1]) % mod
            yield h

    inverted_index = defaultdict(set)
    for idx, mega_index in enumerate(mega_index_list):
        tokens = to_tokens(mega_index)
        if not tokens:
            continue
        for h in window_hashes(tokens):
            inverted_index[h].add(idx)

    candidate_map = {}
    for traj_id, vllm_index in vllm_index_dict.items():
        candidate_map[traj_id] = []
        tokens = to_tokens(vllm_index)
        if not tokens:
            continue
        vote_count = defaultdict(int)
        num_windows = 0  # always >= 1 here, so the ratio below cannot divide by zero
        for h in window_hashes(tokens):
            num_windows += 1
            for idx in inverted_index.get(h, ()):
                vote_count[idx] += 1
        for idx, votes in sorted(vote_count.items(), key=lambda x: x[1], reverse=True):
            if votes / num_windows > thres:
                candidate_map[traj_id].append(idx)
    return candidate_map
def find_inference_step_range(train_step, last_train_step=-1):
    """Return the rollout dump steps whose files were written between two training dumps.

    A rollout step qualifies when any of its ``rank_{rank}/tensor_data.pkl`` files (ranks
    from VLLM_LD_RANKS) has a modification time strictly between two bounds:
    - lower: the newest training dump of ``last_train_step``, or four hours before the
      ``train_step`` dump when ``last_train_step`` is -1;
    - upper: the newest training dump of ``train_step``.

    Returns:
        Sorted list of unique rollout step numbers.

    Raises:
        OSError: if a training dump file for ``train_step``/``last_train_step`` is missing.
    """
    actor_root = "./mathtool_actor_sentinel_dump_data"
    rollout_root = "./mathtool_rollout_sentinel_dump_data"
    # NOTE(review): these paths use `rank_{rank}`, whereas load_mega_logits/load_vllm_logits
    # read `{node_ip}-rank_{rank}`; confirm which layout the dumps actually use.

    def newest_train_mtime(step):
        return max(
            (
                os.path.getmtime(f"{actor_root}/step_{step}/rank_{rank}/tensor_data.pkl")
                for tp_rank in MEGA_LD_RANKS
                for rank in tp_rank
            ),
            default=0,
        )

    train_mtime = newest_train_mtime(train_step)
    if last_train_step != -1:
        last_train_mtime = newest_train_mtime(last_train_step)
    else:
        last_train_mtime = train_mtime - 3600 * 4

    step_pattern = re.compile(r"step_(\d+)")
    inference_steps = set()
    for item in os.listdir(rollout_root):
        item_path = os.path.join(rollout_root, item)
        found = step_pattern.search(item)
        # Skip non-directories and names without a numeric "step_<n>" (previously a TypeError).
        if found is None or not os.path.isdir(item_path):
            continue
        vllm_step = int(found[1])
        for tp_rank in VLLM_LD_RANKS:
            for rank in tp_rank:
                infer_file_path = os.path.join(item_path, f"rank_{rank}", "tensor_data.pkl")
                if os.path.exists(infer_file_path) and (
                    last_train_mtime < os.path.getmtime(infer_file_path) < train_mtime
                ):
                    inference_steps.add(vllm_step)
    return sorted(inference_steps)
def load_mega_logits(step):
    """Load Megatron top-100 logit dumps for ``step`` and merge them into per-entry tensors.

    Processing, per rank group in MEGA_LD_RANKS (paired with the node IP at the same
    position in MEGA_IPS):
    1. Load every rank's pickle.
    2. Stitch the vocabulary shards together.
    3. Keep the top 100 logits / ids of the union.

    Every MEGA_CP_SIZE consecutive groups are then concatenated entry-by-entry along dim 0.

    Args:
        step: training step whose dump directory is read.

    Returns:
        ``(merged_mega_logits_list, merged_mega_index_list, merged_mega_prob_list)``:
        parallel lists with one tensor per merged entry, holding:
        - the top-100 logits (descending);
        - their global vocabulary ids (int64);
        - approximate probabilities (see the tail-mass comment below).

    Raises:
        ValueError: if ranks of one group dumped a different number of entries.
    """
    merged_mega_logits_list = []
    merged_mega_index_list = []
    merged_mega_prob_list = []
    # Per-group results, buffered until MEGA_CP_SIZE groups have been read.
    divided_mega_logits_list = []
    divided_mega_index_list = []
    divided_mega_prob_list = []
    # Width of one vocabulary shard; rank k of a group is assumed to hold ids
    # [k * partial_vocab, (k + 1) * partial_vocab) — TODO confirm against the dumper.
    partial_vocab = VOCAB_SIZE // MEGA_TP_SIZE
    for tp_range, node_ip in zip(MEGA_LD_RANKS, MEGA_IPS):
        tmp_traj_num = -1
        mega_logits_list = []
        for rank in tp_range:
            mega_path = f"./mathtool_actor_sentinel_dump_data/step_{step}/{node_ip}-rank_{rank}/tensor_data.pkl"
            # NOTE: torch.load unpickles the file; only use on trusted dumps.
            mega_logits = torch.load(mega_path)
            # All ranks of a group must hold the same number of entries.
            if tmp_traj_num == -1:
                tmp_traj_num = len(mega_logits["top100_logits"][MEGA_LAYER_NAME])
            else:
                expected_len = len(mega_logits["top100_logits"][MEGA_LAYER_NAME])
                if tmp_traj_num != expected_len:
                    raise ValueError(
                        f"Data length mismatch: tmp_traj_num ({tmp_traj_num}) does not match "
                        f"expected length ({expected_len})"
                    )
            mega_logits_list.append(mega_logits)
        tmp_mega_logits_list = []
        tmp_mega_index_list = []
        tmp_mega_prob_list = []
        for i in range(tmp_traj_num):
            # `.npu()` moves the tensor to an Ascend NPU (presumably via torch_npu — TODO confirm).
            logits_list = [x["top100_logits"][MEGA_LAYER_NAME][i].npu() for x in mega_logits_list]
            # Each index along dim 2 is handled as a separate sequence (assumed micro-batch axis).
            mb = logits_list[0].shape[2]
            for j in range(mb):
                ori_logits_list = []
                ori_index_list = []
                for k, logit in enumerate(logits_list):
                    # Slot 0 of dim 0 holds logit values, slot 1 the shard-local vocabulary ids.
                    ori_logit = logit[0, :, j, :]
                    index = logit[1, :, j, :].long()
                    # Shift shard-local ids to global ids. In-place add: relies on `.long()`
                    # returning a copy, i.e. the dump not already being int64 (NOTE(review)).
                    index += k * partial_vocab
                    ori_logits_list.append(ori_logit)
                    ori_index_list.append(index)
                mega_logits = torch.cat(ori_logits_list, dim=-1)
                mega_index = torch.cat(ori_index_list, dim=-1)
                # Re-rank the union of all shards and keep the global top 100.
                top_mega_logits, top_index = mega_logits.sort(descending=True)
                top_mega_logits = top_mega_logits[..., :100]
                mega_index = torch.gather(mega_index, dim=-1, index=top_index[..., :100])
                tmp_mega_logits_list.append(top_mega_logits)
                tmp_mega_index_list.append(mega_index)
                # Max-shifted softmax over the kept 100 logits. The VOCAB_SIZE - 100 dropped tokens
                # are each approximated by the smallest kept logit. If the kept entries are the
                # true top 100, dropped logits can only be smaller, so the denominator is an upper
                # bound and these probabilities are lower bounds.
                safe_exp_logits = torch.exp(
                    top_mega_logits - torch.amax(top_mega_logits, dim=-1, keepdim=True)
                )
                tail_exp = (VOCAB_SIZE - 100) * safe_exp_logits[..., -1]
                prob = safe_exp_logits / (torch.sum(safe_exp_logits, dim=-1) + tail_exp).unsqueeze(-1)
                tmp_mega_prob_list.append(prob)
        divided_mega_logits_list.append(tmp_mega_logits_list)
        divided_mega_index_list.append(tmp_mega_index_list)
        divided_mega_prob_list.append(tmp_mega_prob_list)
        # Once MEGA_CP_SIZE groups are buffered, concatenate their j-th entries along dim 0,
        # then reset the buffers for the next set of groups.
        if len(divided_mega_logits_list) == MEGA_CP_SIZE:
            cp_traj_num = len(divided_mega_logits_list[0])
            for j in range(cp_traj_num):
                logits = [x[j] for x in divided_mega_logits_list]
                index = [x[j] for x in divided_mega_index_list]
                prob = [x[j] for x in divided_mega_prob_list]
                merged_mega_logits_list.append(torch.cat(logits, dim=0))
                merged_mega_index_list.append(torch.cat(index, dim=0))
                merged_mega_prob_list.append(torch.cat(prob, dim=0))
            divided_mega_logits_list = []
            divided_mega_index_list = []
            divided_mega_prob_list = []
    traj_num = len(merged_mega_logits_list)
    print("Mega trajectory number:", traj_num)
    return merged_mega_logits_list, merged_mega_index_list, merged_mega_prob_list
def load_vllm_logits(step):
    """Load vLLM top-100 logit dumps around ``step`` and group them by trajectory.

    Which dumps are read:
    - dump steps ``step - 1`` and ``step`` (only ``step`` when ``step == 0``);
    - only the first rank of each group in VLLM_LD_RANKS;
    - missing dump files are skipped.

    How they are grouped:
    - per request id, the per-position slices are concatenated along dim 1;
    - slot 0 of dim 0 holds logits and slot 1 vocabulary ids;
    - request ids are grouped into trajectories by dropping their last two '-'-separated fields.

    Args:
        step: target training step.

    Returns:
        ``(merged_vllm_logits_dict, merged_vllm_index_dict, merged_vllm_prob_dict,
        divided_vllm_logits_dict, divided_vllm_index_dict)``, all keyed by traj_id.
        - The merged dicts hold logits, int64 ids and approximate probabilities.
        - The divided dicts hold a list of per-piece logits / ids (a single piece per
          trajectory as written).

    Raises:
        RuntimeError: if a dump has no "input_id" entry.
        ValueError: if an entry's request-id count differs from its logits' dim 1.
    """
    vllm_logits_dict = {}
    if step > 0:
        inference_steps = [step - 1, step]
    else:
        inference_steps = [step]
    for tp_range, node_ip in zip(VLLM_LD_RANKS, VLLM_IPS):
        # Only the first rank of each group is read.
        rank = tp_range[0]
        for i in inference_steps:
            vllm_path = f"./mathtool_rollout_sentinel_dump_data/step_{i}/{node_ip}-rank_{rank}/tensor_data.pkl"
            if not os.path.exists(vllm_path):
                continue
            # NOTE: torch.load unpickles the file; only use on trusted dumps.
            vllm_logits = torch.load(vllm_path)
            if "input_id" not in vllm_logits.keys():
                raise RuntimeError("Not found vLLM req_id information")
            for l in range(len(vllm_logits["input_id"][VLLM_LAYER_NAME])):
                if len(vllm_logits["input_id"][VLLM_LAYER_NAME][l]) > 0:
                    req_id = vllm_logits["input_id"][VLLM_LAYER_NAME][l][0]
                    # NOTE(review): re.search returns None (-> TypeError) if the id lacks "-trainstep-<n>".
                    train_step = int(re.search(".*-(trainstep)-(\d+)", req_id)[2])
                    # NOTE(review): `0 <= x < -2` can never be true, so this check never raises;
                    # confirm the intended allowed range of train_step - step.
                    if 0 <= train_step - step < -2:
                        raise RuntimeError(
                            f"Train step in {vllm_path} and target step are mismatched. "
                            f"Got {train_step} and target is {step}"
                        )
                for req_id in vllm_logits["input_id"][VLLM_LAYER_NAME][l]:
                    if req_id not in vllm_logits_dict.keys():
                        vllm_logits_dict[req_id] = []
                # Dim 1 of each top100_logits entry lines up with the entry's request ids.
                len_input_id = len(vllm_logits["input_id"][VLLM_LAYER_NAME][l])
                len_top100 = vllm_logits["top100_logits"][VLLM_LAYER_NAME][l].shape[1]
                if len_input_id != len_top100:
                    raise ValueError(
                        f"Dimension mismatch at layer {l}: input_id length ({len_input_id}) "
                        f"does not match top100_logits shape ({len_top100})"
                    )
                for j in range(len(vllm_logits["input_id"][VLLM_LAYER_NAME][l])):
                    req_id = vllm_logits["input_id"][VLLM_LAYER_NAME][l][j]
                    logits = vllm_logits["top100_logits"][VLLM_LAYER_NAME][l][:, j, :]
                    # `.npu()` moves the tensor to an Ascend NPU (presumably via torch_npu — TODO confirm).
                    vllm_logits_dict[req_id].append(logits.unsqueeze(1).npu())
    # Stack each request's per-position slices along dim 1.
    for req_id in vllm_logits_dict.keys():
        vllm_logits_dict[req_id] = torch.cat(vllm_logits_dict[req_id], dim=1)
    merged_vllm_logits_dict = {}
    merged_vllm_index_dict = {}
    merged_vllm_prob_dict = {}
    divided_vllm_logits_dict = {}
    divided_vllm_index_dict = {}
    inter_time = 0  # NOTE(review): unused.
    for req_id, v in vllm_logits_dict.items():
        # traj_id is the request id without its last two '-'-separated fields.
        req_id_seg = req_id.split('-')
        traj_id = '-'.join(req_id_seg[:-2])
        # NOTE(review): assignment, not append — if several request ids map to the same
        # traj_id, only the last one iterated is kept.
        divided_vllm_logits_dict[traj_id] = [v]
        divided_vllm_index_dict[traj_id] = [v]
    for traj_id in divided_vllm_logits_dict.keys():
        merged_vllm_logits_dict[traj_id] = torch.cat(divided_vllm_logits_dict[traj_id], dim=1)
        # Split the stacked tensor: slot 1 -> vocabulary ids, slot 0 -> logits.
        merged_vllm_index_dict[traj_id] = merged_vllm_logits_dict[traj_id][1].long()
        merged_vllm_logits_dict[traj_id] = merged_vllm_logits_dict[traj_id][0]
        for i in range(len(divided_vllm_logits_dict[traj_id])):
            # Empty pieces are left unsplit (callers check nelement() == 0).
            if divided_vllm_logits_dict[traj_id][i].nelement() > 0:
                divided_vllm_index_dict[traj_id][i] = divided_vllm_logits_dict[traj_id][i][1].long()
                divided_vllm_logits_dict[traj_id][i] = divided_vllm_logits_dict[traj_id][i][0]
        logit = merged_vllm_logits_dict[traj_id]
        # Max-shifted softmax over the 100 kept logits; the VOCAB_SIZE - 100 dropped tokens are
        # each approximated by the smallest kept logit (an upper bound on the tail mass when the
        # kept entries are the true top 100).
        safe_exp_logits = torch.exp(logit - torch.amax(logit, dim=-1, keepdim=True))
        tail_exp = (VOCAB_SIZE - 100) * safe_exp_logits[..., -1]
        prob = safe_exp_logits / (torch.sum(safe_exp_logits, dim=-1) + tail_exp).unsqueeze(-1)
        merged_vllm_prob_dict[traj_id] = prob
    print("vllm trajectory number:", len(merged_vllm_logits_dict.keys()))
    return (
        merged_vllm_logits_dict,
        merged_vllm_index_dict,
        merged_vllm_prob_dict,
        divided_vllm_logits_dict,
        divided_vllm_index_dict,
    )
@torch.no_grad()
def top_logits_match(step):
    """Compare Megatron (training) and vLLM (inference) top-100 token distributions for ``step``.

    Steps:
      1. Load both dumps.
      2. Build, or load from ``./outputs/rollout_actor_match_step_{step}.pkl``, a match map
         pairing vLLM trajectories with Megatron spans:
         - k-gram hashing proposes candidates;
         - sliding-window matching scores them;
         - greedy_match assigns them best-first.
      3. Print per-match and average token match ratios.
      4. Produce the comparison outputs:
         - plot per-trajectory log p of the vLLM column-0 token on both sides;
         - print the mean per-token KL(vLLM || Megatron) over the dense top-100 support;
         - scatter-plot probabilities where both sides exceed 1e-2 and print their Pearson r.

    Args:
        step: training step to analyse; also used in output file names.
    """
    merged_mega_logits_list, merged_mega_index_list, merged_mega_prob_list = load_mega_logits(step)
    (
        merged_vllm_logits_dict,
        merged_vllm_index_dict,
        merged_vllm_prob_dict,
        divided_vllm_logits_dict,
        divided_vllm_index_dict,
    ) = load_vllm_logits(step)

    cache_path = f"./outputs/rollout_actor_match_step_{step}.pkl"
    if os.path.exists(cache_path):
        # NOTE: pickle executes arbitrary code on load; only read caches this script wrote.
        with open(cache_path, "rb") as f:
            match_map = pickle.load(f)
    else:
        candidate_map = k_gram_match(merged_mega_index_list, merged_vllm_index_dict)
        edges = []
        for traj_id, candidates in candidate_map.items():
            for idx in candidates:
                current_match_num, st_list, vllm_total_length = single_trajectory_match(
                    divided_vllm_index_dict[traj_id], merged_mega_index_list[idx][..., 0].squeeze()
                )
                if vllm_total_length == 0:
                    continue  # nothing was placed; the ratio below would divide by zero
                edges.append((float(current_match_num) / vllm_total_length, traj_id, idx, current_match_num, st_list))
        edges.sort(key=lambda x: x[0], reverse=True)
        # greedy_match (as written in this file) only reads per-trajectory lengths from this list.
        match_map = greedy_match(edges, divided_vllm_index_dict, merged_mega_prob_list)
        os.makedirs("./outputs", exist_ok=True)
        with open(cache_path, "wb") as f:
            pickle.dump(match_map, f)

    # Empty placements cannot be scored (division by zero, mega_index[0] below).
    match_map = [m for m in match_map if len(m["vllm_index"]) > 0 and len(m["mega_index"]) > 0]
    if not match_map:
        print("No matched trajectories; nothing to compare.")
        return

    traj_num = len(match_map)
    match_ratio = 0.0
    target_draw_traj_id = []
    for m in match_map:
        ratio = m['match_num'] / len(m['vllm_index'])
        print(m['mega_id'], m['req_id'], m['match_num'], ratio, m["mega_index"][0], m["mega_index"][-1])
        match_ratio += ratio
        if m['mega_id'] not in target_draw_traj_id:
            target_draw_traj_id.append(m['mega_id'])
    match_ratio /= traj_num
    print(len(match_map), len(target_draw_traj_id), f"Matched idx {target_draw_traj_id}")
    print(f"Average match ratio: {match_ratio}")

    slot_of = {mega_id: slot for slot, mega_id in enumerate(target_draw_traj_id)}
    vllm_probs = []
    mega_probs = []
    all_kl = []
    # Full-length log p traces per drawn Megatron trajectory; unmatched positions stay -inf.
    complete_mega_probs = [np.full(merged_mega_prob_list[i].shape[0], -np.inf) for i in target_draw_traj_id]
    complete_vllm_probs = [np.full(merged_mega_prob_list[i].shape[0], -np.inf) for i in target_draw_traj_id]
    for m in match_map:
        mega_id = m["mega_id"]
        req_id = m["req_id"]
        mega_index = m['mega_index']
        vllm_index = m['vllm_index']
        mega_p = merged_mega_prob_list[mega_id][mega_index].cpu()
        vllm_p = merged_vllm_prob_dict[req_id][vllm_index].cpu()
        mega_pred = merged_mega_index_list[mega_id][mega_index].cpu()
        vllm_pred = merged_vllm_index_dict[req_id][vllm_index].cpu()
        # Scatter each side's top-100 probabilities into dense (tokens, VOCAB_SIZE) tensors so
        # both distributions share one support; ids outside a side's top-100 stay 0.
        # NOTE: memory grows as tokens * VOCAB_SIZE per match.
        rows = torch.arange(vllm_p.shape[0]).view(-1, 1)
        mega_prob_tensor = torch.zeros((vllm_p.shape[0], VOCAB_SIZE))
        vllm_prob_tensor = torch.zeros((vllm_p.shape[0], VOCAB_SIZE))
        mega_prob_tensor[rows, mega_pred] = mega_p
        vllm_prob_tensor[rows, vllm_pred] = vllm_p
        valid_prob_index = (mega_prob_tensor > 1e-2) & (vllm_prob_tensor > 1e-2)
        mega_probs.extend(mega_prob_tensor[valid_prob_index].tolist())
        vllm_probs.extend(vllm_prob_tensor[valid_prob_index].tolist())
        # Probability both sides give the vLLM column-0 token (presumably its top-1).
        top1 = vllm_pred[:, 0].view(-1, 1)
        mega_max_probs = mega_prob_tensor[rows, top1].squeeze()
        vllm_max_probs = vllm_prob_tensor[rows, top1].squeeze()
        slot = slot_of[mega_id]
        complete_mega_probs[slot][mega_index] = np.log(mega_max_probs.numpy())
        complete_vllm_probs[slot][mega_index] = np.log(vllm_max_probs.numpy())
        # Per-token KL(vLLM || Megatron); the clamp avoids log(0) outside either top-100.
        mega_prob_tensor = mega_prob_tensor.clamp(min=1e-20)
        vllm_prob_tensor = vllm_prob_tensor.clamp(min=1e-20)
        kl = (vllm_prob_tensor * (torch.log(vllm_prob_tensor) - torch.log(mega_prob_tensor))).sum(dim=1)
        all_kl.append(kl)

    for slot, mega_id in enumerate(target_draw_traj_id):
        draw_train_infer_max_token_logp(
            np.expand_dims(complete_vllm_probs[slot], 0), np.expand_dims(complete_mega_probs[slot], 0), step, mega_id
        )
    avg_kl_div = torch.cat(all_kl, dim=0).mean().item()
    print(f"Average KL divergence: {avg_kl_div}")
    corr = draw_train_infer_logits_diff(np.array(vllm_probs), np.array(mega_probs), step)
    print(f"Pearson correlation: {corr}")
def router_index_compare(
    mega_path="./sentinel_dump_data_router_index/step_0/rank_0/tensor_data.pkl",
    vllm_path="./vllm_sentinel_dump_data_router_index/step_0/rank_0/tensor_data.pkl",
    layers=48,
    mega_pre_id=5,
    k=8,
):
    """Print, per decoder layer, Megatron's dumped router indices next to vLLM's top-k gate indices.

    Args:
        mega_path: Megatron router-index dump; its ``'raw_data'`` entry is read.
        vllm_path: vLLM gate-logit dump; its ``'raw_data'`` entry is read.
        layers: number of layers to compare.
        mega_pre_id: number of leading rows of the Megatron index tensor to print.
        k: number of largest vLLM gate logits per row to report (was a hard-coded 8).
    """
    # NOTE: torch.load unpickles the files; only use on trusted dumps.
    mega_index = torch.load(mega_path)['raw_data']
    vllm_logits = torch.load(vllm_path)['raw_data']
    for i in range(layers):
        mega_layer_index = mega_index[f'module.decoder.layers.{i}.mlp.router_forward'][0][:mega_pre_id]
        vllm_layer_logits = vllm_logits[f'model.layers.{i}.mlp.gate_forward'][0]
        # Column indices of the k largest gate logits in each row.
        vllm_layer_index = vllm_layer_logits.sort(descending=True, dim=-1)[1][:, :k]
        print(i, mega_layer_index, vllm_layer_index)
if __name__ == '__main__':
    # Analyse every training step listed in STEP. The loop variable is bound at module
    # scope as `current_step`; code elsewhere in this file may read that global, so keep the name.
    for current_step in STEP:
        top_logits_match(current_step)