import json
import math
import re
from decimal import Decimal, InvalidOperation, ROUND_HALF_UP
from typing import Tuple
from openjiuwen_deepsearch.common.common_constants import CHINESE, ENGLISH
def _has_mixed_unit_separators(unit: str) -> bool:
unit_lower = unit.lower()
return any(sep in unit for sep in ("或", "/", "|", ",", ";")) or " and " in unit_lower
def validate_visualization_extraction_schema(payload: dict) -> bool:
    """
    Check the stage-1 visualization extraction output against its schema:
    {
      "image_title": str,
      "image_type": "pie|line|timeline|bar",
      "records": [[x: str, value_string: str, unit_string: str], ...]
    }

    Returns True only when every rule holds: the payload is a dict, the title
    is a string, the type is one of the four supported kinds, and each record
    is a list of exactly three strings whose first two entries are not blank.
    For every type except "timeline" the unit must also be non-empty and must
    not look like a combination of several units.
    """
    if not isinstance(payload, dict):
        return False

    title = payload.get("image_title", "")
    chart_kind = payload.get("image_type", "")
    rows = payload.get("records", [])

    if not isinstance(title, str):
        return False
    if chart_kind not in ("bar", "line", "pie", "timeline"):
        return False
    if not isinstance(rows, list):
        return False

    # Timeline rows carry event text in the value slot, so no unit is required.
    needs_unit = chart_kind != "timeline"

    for entry in rows:
        if not isinstance(entry, list) or len(entry) != 3:
            return False
        if not all(isinstance(cell, str) for cell in entry):
            return False
        label, raw_value, raw_unit = entry
        if not label.strip() or not raw_value.strip():
            return False
        if needs_unit and (not raw_unit or _has_mixed_unit_separators(raw_unit)):
            return False
    return True
def validate_visualization_normalization_schema(
    normalized_payload: dict,
    image_type: str,
) -> bool:
    """
    Validate stage-2 normalization output schema:
    {
      "unit": str,
      "records": [[x: str, value: number], ...]
    }

    Args:
        normalized_payload: Parsed normalization output to check.
        image_type: Chart kind the payload belongs to; only "bar", "line"
            and "pie" are accepted here.

    Returns:
        True when the payload matches the schema, False otherwise. The unit
        must be a non-empty string that does not combine several units, and
        every record must be a ``[str, finite number]`` pair. Pie charts
        additionally reject negative values. This function never raises for
        malformed record values.
    """
    if not isinstance(normalized_payload, dict) or image_type not in ("bar", "line", "pie"):
        return False

    unit = normalized_payload.get("unit", "")
    records = normalized_payload.get("records", [])
    if not isinstance(unit, str) or not unit or _has_mixed_unit_separators(unit):
        return False
    if not isinstance(records, list):
        return False

    for row in records:
        if not isinstance(row, list) or len(row) != 2:
            return False
        x, value = row
        if not isinstance(x, str):
            return False
        # bool is a subclass of int, but JSON true/false is not a number.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return False
        try:
            numeric = float(value)
        except OverflowError:
            # Integers too large for a float cannot be plotted; report the
            # payload as invalid instead of letting the exception escape.
            return False
        if not math.isfinite(numeric):
            return False
        if image_type == "pie" and numeric < 0:
            return False
    return True
class ArticlePart:
    """Per-language lookup tables for the abstract, conclusion and reference parts."""

    # Names of the article parts covered by the tables below.
    parts = ["abstract", "conclusion", "reference"]

    # Pattern text for each part, keyed by part name and then by language.
    patterns = {
        "abstract": {CHINESE: r"摘要", ENGLISH: r"Abstract"},
        "conclusion": {CHINESE: r"结论", ENGLISH: r"Conclusion"},
        "reference": {CHINESE: r"参考文章", ENGLISH: r"References"},
    }

    # Placeholder section text for a part, keyed by part name and language.
    not_found_prompts = {
        "abstract": {
            CHINESE: "# 摘要\n\n[未能从生成内容中提取到摘要]",
            ENGLISH: "# Abstract\n\n[No abstract could be extracted from the generated content]",
        },
        "conclusion": {
            CHINESE: "# 结论\n\n[未能从生成内容中提取到结论]",
            ENGLISH: "# Conclusion\n\n[No conclusion could be extracted from the generated content]",
        },
        "reference": {
            CHINESE: "# 参考文章\n\n[未能从生成内容中提取到参考文章]",
            ENGLISH: "# Reference Articles\n\n[No reference could be extracted from the generated content]",
        },
    }

    # Markdown heading (including the trailing blank line) for each part.
    titles = {
        "abstract": {
            CHINESE: "# 摘要\n\n",
            ENGLISH: "# Abstract\n\n",
        },
        "conclusion": {
            CHINESE: "# 结论\n\n",
            ENGLISH: "# Conclusion\n\n",
        },
        "reference": {
            CHINESE: "# 参考文章\n\n",
            ENGLISH: "# Reference Articles\n\n",
        },
    }

    @classmethod
    def get_not_found_prompt(cls, part, lang):
        """Return the placeholder text for ``part`` in ``lang``, or "" if unknown."""
        per_language = cls.not_found_prompts.get(part, {})
        return per_language.get(lang, "")

    @classmethod
    def get_title(cls, part, lang):
        """Return the section heading for ``part`` in ``lang``, or "" if unknown."""
        per_language = cls.titles.get(part, {})
        return per_language.get(lang, "")
class MarkdownOutlineRenumber:
    """Rewrite the numbering of level 1-3 Markdown ATX headers.

    Each ``#``, ``##`` or ``###`` header gets a fresh outline number
    (``1.``, ``1.1``, ``1.1.1``) that replaces any leading digits/dots already
    present after the hashes. Fenced code blocks, ``$$`` math blocks, indented
    code lines and block quotes are left untouched.

    The counters, the fence/math flags and ``history`` live on the instance,
    so they carry over between successive ``renumber_headers`` calls.
    """

    def __init__(self):
        self.counters = {}  # header level -> current counter value
        self.prev_level = 0  # level of the most recently renumbered header
        self.history = []  # "from[...] -> to[...]" entry per changed prefix
        self.in_code_block = False
        self.in_math_block = False

    @staticmethod
    def _parse_header(match) -> Tuple[int, str, str]:
        """Split a header match into (level, hash prefix, full matched text)."""
        full_match = match.group(0)
        outline_part = match.group(1)
        level = outline_part.count("#")
        return level, outline_part, full_match

    def renumber_headers(self, content: str) -> str:
        """Return ``content`` with its level 1-3 headers renumbered.

        Args:
            content: Markdown text; lines are split on ``"\\n"``.

        Returns:
            The text with header prefixes replaced. Lines that are not
            recognised headers are returned unchanged.
        """
        pattern = r"^ *(#{1,3}(?!\#)) +([0-9.]*) *"
        output_lines = []
        for line in content.split("\n"):
            # Fence and math delimiters toggle their block state and are
            # themselves copied through untouched.
            if re.match(r"^ *```.*$", line):
                self.in_code_block = not self.in_code_block
                output_lines.append(line)
                continue
            if re.match(r"^ *\$\$ *$", line):
                self.in_math_block = not self.in_math_block
                output_lines.append(line)
                continue
            if self.in_code_block or self.in_math_block:
                output_lines.append(line)
                continue
            # Four leading spaces or a tab mark an indented code block. One to
            # three spaces are still a valid header indent, which the pattern
            # accepts through its leading " *".
            if line.startswith("    ") or line.startswith("\t"):
                output_lines.append(line)
                continue
            if line.strip().startswith(">"):
                output_lines.append(line)
                continue
            output_lines.append(re.sub(pattern, self._replace_header, line))
        return "\n".join(output_lines)

    def _update_counters(self, level: int):
        """Advance the counter for ``level``, resetting deeper levels on ascent."""
        if level < self.prev_level:
            for i in range(level + 1, max(self.counters.keys(), default=0) + 1):
                if i in self.counters:
                    self.counters[i] = 0
        if level not in self.counters:
            self.counters[level] = 0
        self.counters[level] += 1
        self.prev_level = level

    def _generate_new_number(self, level: int) -> str:
        """Build the dotted outline number for a header at ``level``.

        A parent level that has no header yet in the current section - either
        never seen, or reset to zero by a shallower header - is counted as 1,
        so a skipped level never produces a ``0`` component.
        """
        new_number_parts = []
        for i in range(1, level + 1):
            if self.counters.get(i, 0) < 1:
                self.counters[i] = 1
            new_number_parts.append(str(self.counters[i]))
        return ".".join(new_number_parts)

    def _replace_header(self, match) -> str:
        """``re.sub`` callback: return the renumbered header prefix."""
        level, outline_part, full_match = MarkdownOutlineRenumber._parse_header(match)
        self._update_counters(level)
        new_number = self._generate_new_number(level)
        # Only top-level numbers carry a trailing dot ("1." versus "1.1").
        level_1_dot = "." if level == 1 else ""
        after_replace = f"{outline_part} {new_number}{level_1_dot} "
        if full_match != after_replace:
            self.history.append(f"from[{full_match}] -> to[{after_replace}]")
        return after_replace
class XYChartMermaidGenerator:
"""
Unified generator for Mermaid `xychart-beta` charts.
It supports both "bar" and "line" and chooses the final Mermaid statement based on `image_type`.
The axis/scaling/formatting behavior follows the previous bar chart implementation.
"""
LABEL_MAX_LEN = 15
WIDTH_MIN = 360
WIDTH_MAX = 960
WIDTH_BASE = 220
WIDTH_PER_CATEGORY = 70
HEIGHT = 360
TARGET_TOP_RATIO = 1.0
NEG_TOP_RATIO = 0.95
Y_MIN_EXAGGERATION_MAX = 6.0
PAD_RATIO_TIGHT = 0.08
PAD_RATIO_LOOSE = 0.12
PAD_UP_FRACTION = 0.5
BAR_ZERO_GAP_RATIO = 2.5
HORIZONTAL_TOTAL_LABEL_LIMIT = 80.0
LABEL_WEIGHT_CJK = 2.0
LABEL_WEIGHT_ASCII = 1.0
NICE_MULTIPLIERS = (
1, 1.1, 1.2, 1.25, 1.5, 1.6, 1.8,
2, 2.2, 2.5, 3, 3.5,
4, 4.5, 5, 6, 7, 8, 9, 10, 10.5, 11, 12, 15,
)
@classmethod
def _sanitize_label(cls, label: str | None) -> str:
raw = (str(label) if label is not None else "").strip().replace('"', "'")
if not raw:
return "Item"
return raw
@classmethod
def _label_weight_length(cls, label: str) -> float:
total = 0.0
for ch in label:
if "\u4e00" <= ch <= "\u9fff":
total += cls.LABEL_WEIGHT_CJK
else:
total += cls.LABEL_WEIGHT_ASCII
return total
@classmethod
def _should_use_horizontal(cls, labels: list[str], count: int) -> bool:
if count <= 0:
return False
weights = [cls._label_weight_length(label) for label in labels]
total_len = sum(weights)
max_len = max(weights, default=0.0)
per_label_limit = cls.HORIZONTAL_TOTAL_LABEL_LIMIT / max(count, 1)
return not (
total_len <= cls.HORIZONTAL_TOTAL_LABEL_LIMIT
and max_len <= per_label_limit
)
@staticmethod
def _format_num(value: float) -> str:
"""Format numbers for Mermaid (avoid scientific notation)."""
abs_val = abs(value)
if abs_val == 0:
return "0"
if abs(value - round(value)) < 1e-6 and abs_val >= 1:
return str(int(round(value)))
if abs_val >= 1:
return f"{value:.2f}".rstrip("0").rstrip(".")
if abs_val >= 0.01:
return f"{value:.3f}".rstrip("0").rstrip(".")
decimals = max(6, int(-math.floor(math.log10(abs_val))) + 2)
return f"{value:.{decimals}f}".rstrip("0").rstrip(".")
@classmethod
def _nice_ceil(cls, value: float) -> float:
"""Ceil to a 'nice' number for y-axis max."""
if value <= 0:
return 1.0
exp = 10 ** math.floor(math.log10(value))
for m in cls.NICE_MULTIPLIERS:
candidate = m * exp
if candidate >= value:
return candidate
return 10 * exp
@classmethod
def _nice_step(cls, target_step: float) -> float:
"""Pick a 'nice' step size close to target_step."""
if target_step <= 0:
return 1.0
exp = 10 ** math.floor(math.log10(target_step))
for m in cls.NICE_MULTIPLIERS:
candidate = m * exp
if candidate >= target_step:
return candidate
return 10 * exp
@classmethod
def generate_from_json(cls, json_string: str) -> str:
"""将输入 JSON 字符串转换为目标报告片段。"""
if not json_string:
raise ValueError("empty input")
data = json.loads(json_string)
if not data or data.get("image_type") not in ("bar", "line"):
raise ValueError("input must be a bar/line chart visualization JSON")
chart_type = data.get("image_type")
raw_unit = (data.get("unit") or "").strip()
if cls._detect_mixed_unit(raw_unit):
raise ValueError("mixed units are not allowed for a single chart")
records = data.get("records", [])
if not records or len(records) < 2:
raise ValueError("records are required")
x_values: list[str] = []
raw_values: list[float] = []
for row in records:
if not isinstance(row, list) or len(row) != 2:
raise ValueError("each record must be a 2-element array")
label, value = row[0], row[1]
if not isinstance(label, str):
raise ValueError("record[0] must be a string")
if not isinstance(value, (int, float)) or not math.isfinite(float(value)):
raise ValueError("record[1] must be a finite number")
x_values.append(cls._sanitize_label(label))
raw_values.append(float(value))
display_values = raw_values
y_min, y_max = cls._compute_y_range(display_values, chart_type)
unit_title = raw_unit or "Value"
x_axis = ", ".join(f'"{c}"' for c in x_values)
series_values = ", ".join(cls._format_num(v) for v in display_values)
y_min_s = cls._format_num(y_min)
y_max_s = cls._format_num(y_max)
count = len(x_values)
width = max(
cls.WIDTH_MIN,
min(cls.WIDTH_MAX, cls.WIDTH_BASE + count * cls.WIDTH_PER_CATEGORY),
)
use_horizontal = (
chart_type == "bar" and cls._should_use_horizontal(x_values, count)
)
chart_orientation = (
"xychart-beta horizontal" if use_horizontal else "xychart-beta"
)
lines = [
"---",
"config:",
f" horizontal: {'true' if use_horizontal else 'false'}",
f" width: {width}",
f" height: {cls.HEIGHT}",
" showDataLabel: true",
" themeVariables:",
" xyChart:",
" plotColorPalette: '#7c3aed'",
"---",
chart_orientation,
f" x-axis [{x_axis}]",
f' y-axis "{unit_title}" {y_min_s} --> {y_max_s}',
f" {chart_type} [{series_values}]",
]
return "\n".join(lines)
@classmethod
def _nice_floor(cls, value: float) -> float:
if value >= 0:
return 0.0
abs_val = abs(value)
exp = 10 ** math.floor(math.log10(abs_val))
for m in cls.NICE_MULTIPLIERS:
candidate = -m * exp
if candidate <= value:
return candidate
return -10 * exp
@classmethod
def _nice_neg_ceil(cls, value: float) -> float:
"""
"Nice" ceiling for negative numbers (move toward 0).
Returns a negative number >= value.
"""
if value >= 0:
return 0.0
abs_val = abs(value)
exp = 10 ** math.floor(math.log10(abs_val))
for m in reversed(cls.NICE_MULTIPLIERS):
if m > 10:
continue
if m * exp <= abs_val + 1e-12:
return -m * exp
return -exp
@staticmethod
def _detect_mixed_unit(unit: str | None) -> bool:
if not unit:
return False
unit_lower = unit.lower()
return (
"或" in unit
or "/" in unit
or "|" in unit
or "," in unit
or ";" in unit
or " and " in unit_lower
)
@classmethod
def _compute_y_range(
cls, values: list[float], chart_type: str
) -> tuple[float, float]:
if not values:
return 0.0, 1.0
vmin = min(values)
vmax = max(values)
if vmax == 0 and vmin == 0:
return 0.0, 1.0
if vmin < 0:
if vmax > 0:
return cls._nice_floor(vmin), cls._nice_ceil(vmax)
y_min = cls._nice_floor(vmin)
y_max = cls._nice_neg_ceil(vmax * cls.NEG_TOP_RATIO)
if y_max < vmax:
y_max = vmax
if y_max <= y_min:
y_max = vmax
return y_min, y_max
vrange = vmax - vmin
def _padded_range(
min_val: float, max_val: float, force_zero: str | None
) -> tuple[float, float]:
span = max_val - min_val
if span <= 0:
span = max(abs(max_val), abs(min_val), 1.0) * 0.1
denom = max(abs(max_val), abs(min_val), 1e-9)
range_ratio = span / denom
pad_down = span * (
cls.PAD_RATIO_TIGHT if range_ratio < 0.25 else cls.PAD_RATIO_LOOSE
)
pad_up = pad_down * cls.PAD_UP_FRACTION
min_candidate = min_val - pad_down
max_candidate = max_val + pad_up
if force_zero == "min":
min_candidate = 0.0
elif force_zero == "max":
max_candidate = 0.0
span = max_candidate - min_candidate
if span <= 0:
return min_val, max_val
step = cls._nice_step(span / 6.0)
if step <= 0:
return min_val, max_val
y_min = math.floor(min_candidate / step) * step
y_max = math.ceil(max_candidate / step) * step
if force_zero == "min" and y_min < 0:
y_min = 0.0
if force_zero == "max" and y_max > 0:
y_max = 0.0
if y_min > min_val:
y_min = min_val
if y_max < max_val:
y_max = max_val
if max_val >= 0 and y_min < 0:
y_min = 0.0
if min_val <= 0 and y_max > 0 and force_zero == "max":
y_max = 0.0
return y_min, y_max
def _should_include_zero(min_val: float, max_val: float) -> bool:
if chart_type != "bar":
return False
gap_to_zero = min_val if min_val > 0 else -max_val
span = max_val - min_val
if span <= 0:
span = max(abs(max_val), abs(min_val), 1.0) * 0.1
return gap_to_zero <= span * cls.BAR_ZERO_GAP_RATIO
if vmin >= 0:
if _should_include_zero(vmin, vmax):
return _padded_range(vmin, vmax, "min")
return _padded_range(vmin, vmax, None)
if vmax <= 0:
if _should_include_zero(vmin, vmax):
return _padded_range(vmin, vmax, "max")
return _padded_range(vmin, vmax, None)
y_min = cls._nice_floor(vmin)
y_max = cls._nice_ceil(vmax)
return y_min, y_max
class PieChartMermaidGenerator:
    """Generator for Mermaid ``pie`` charts from a pie visualization JSON.

    Invalid input is always reported as ``ValueError``.
    """

    # Label of the slice added when percentage values sum to less than 100.
    OTHER_LABEL = "other"
    # Tolerance used when comparing a percentage total against 100.
    EPSILON = 1e-6

    @classmethod
    def _sanitize_label(cls, label: str) -> str:
        """Return a single-line, quote-safe label ("label" when empty)."""
        label = str(label).strip()
        if not label:
            return "label"
        label = label.replace('"', "'")
        return re.sub(r"\s+", " ", label).strip()

    @staticmethod
    def _plain_decimal_text(dec_value: Decimal) -> str:
        """Render a Decimal in fixed-point form without trailing zeros."""
        text = format(dec_value, "f")
        if "." in text:
            text = text.rstrip("0").rstrip(".")
        return "0" if text in ("-0", "-0.0") else text

    @staticmethod
    def _format_num(value: float) -> str:
        """Format a value without scientific notation or trailing zeros."""
        try:
            dec_value = Decimal(str(value))
        except (InvalidOperation, ValueError):
            return str(value)
        return PieChartMermaidGenerator._plain_decimal_text(dec_value)

    @staticmethod
    def _format_other_value(value: float) -> str:
        """Format the computed remainder, rounded half-up to three decimals."""
        try:
            dec_value = Decimal(str(value)).quantize(
                Decimal("0.001"), rounding=ROUND_HALF_UP
            )
        except (InvalidOperation, ValueError):
            return str(value)
        return PieChartMermaidGenerator._plain_decimal_text(dec_value)

    @staticmethod
    def _to_non_negative_float(value) -> float:
        """Convert a record value to a finite, non-negative float.

        Booleans are rejected even though ``bool`` subclasses ``int``, and
        integers too large for a float raise ValueError instead of
        OverflowError.
        """
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError("record[1] must be a finite number")
        try:
            number = float(value)
        except OverflowError:
            raise ValueError("record[1] must be a finite number") from None
        if not math.isfinite(number):
            raise ValueError("record[1] must be a finite number")
        if number < 0:
            raise ValueError("value must be non-negative")
        return number

    @classmethod
    def generate_from_json(cls, json_string: str) -> str:
        """Convert a pie visualization JSON string into Mermaid ``pie`` source.

        Args:
            json_string: JSON object with ``image_type`` "pie", an optional
                string ``unit`` and ``records`` holding at least two
                ``[label, non-negative number]`` pairs.

        Returns:
            Mermaid source with one slice per record, largest first. When the
            unit contains ``%`` or ``百分比`` and the values sum to less than
            100, an extra "other" slice holding the remainder is appended last.

        Raises:
            ValueError: If the input is empty, is not valid JSON, is not a JSON
                object of the expected shape, contains a malformed record, or
                percentage values sum to more than 100.
        """
        if not json_string:
            raise ValueError("empty input")
        data = json.loads(json_string)
        # Valid JSON may still be a list/str/number; reject it before .get().
        if not isinstance(data, dict) or data.get("image_type") != "pie":
            raise ValueError("input must be a pie chart visualization JSON")

        unit_field = data.get("unit") or ""
        if not isinstance(unit_field, str):
            raise ValueError("unit must be a string")
        # The unit is embedded in a double-quoted slice label.
        unit = unit_field.strip().replace('"', "'")
        percent_mode = bool(unit and ("%" in unit or "百分比" in unit))

        records = data.get("records", [])
        if not isinstance(records, list) or len(records) < 2:
            raise ValueError("records are required")

        # Each slice is (label, value, is_other).
        slices: list[tuple[str, float, bool]] = []
        for row in records:
            if not isinstance(row, list) or len(row) != 2:
                raise ValueError("each record must be a 2-element array")
            label, value = row
            if not isinstance(label, str) or not label.strip():
                raise ValueError("record[0] must be a non-empty string")
            slices.append(
                (cls._sanitize_label(label), cls._to_non_negative_float(value), False)
            )

        if percent_mode:
            total = sum(value for _, value, _ in slices)
            if total > 100.0 + cls.EPSILON:
                raise ValueError("percent values sum exceeds 100")
            if total < 100.0 - cls.EPSILON:
                slices.append((cls.OTHER_LABEL, 100.0 - total, True))

        # Real slices first (largest value first); the synthetic one goes last.
        slices.sort(key=lambda item: (item[2], -item[1]))

        lines = ["pie"]
        for label, value, is_other in slices:
            num_text = (
                cls._format_other_value(value) if is_other else cls._format_num(value)
            )
            display_label = f"{label} ({num_text}{unit})"
            lines.append(f' "{display_label}" : {num_text}')
        return "\n".join(lines)
class TimelineChartMermaidGenerator:
    """
    Generator for Mermaid ``timeline`` diagrams.

    Output Mermaid schema:
    timeline
        <time> : <event><br>...

    No ``title`` line is emitted; only the ``records`` of the input are used.
    Invalid input is always reported as ``ValueError``.
    """

    @staticmethod
    def _format_event_text(text: str) -> str:
        """Trim the event text and turn every line break into ``<br>``."""
        return (
            str(text)
            .strip()
            .replace("\r\n", "\n")
            .replace("\r", "\n")
            .replace("\n", "<br>")
        )

    @staticmethod
    def _format_time_text(text: str) -> str:
        """Collapse whitespace so the time period stays on one line."""
        return " ".join(text.split())

    @classmethod
    def generate_from_json(cls, json_string: str) -> str:
        """Convert a timeline visualization JSON string into Mermaid source.

        Args:
            json_string: JSON object with ``image_type`` "timeline" and
                ``records`` holding at least one ``[time, event]`` pair of
                non-blank strings.

        Returns:
            Mermaid source: the ``timeline`` keyword followed by one
            ``<time> : <event>`` line per record, in input order.

        Raises:
            ValueError: If the input is empty, is not valid JSON, is not a JSON
                object of the expected shape, or contains a malformed record.
        """
        if not json_string:
            raise ValueError("empty input")
        data = json.loads(json_string)
        # Valid JSON may still be a list/str/number; reject it before .get().
        if not isinstance(data, dict) or data.get("image_type") != "timeline":
            raise ValueError("input must be a timeline visualization JSON")
        records = data.get("records", [])
        if not isinstance(records, list) or len(records) < 1:
            raise ValueError("records are required")

        lines = ["timeline"]
        for row in records:
            if not isinstance(row, list) or len(row) != 2:
                raise ValueError("each record must be a 2-element array")
            t, event = row
            if not isinstance(t, str) or not t.strip():
                raise ValueError("record[0] must be a non-empty string time")
            if not isinstance(event, str) or not event.strip():
                raise ValueError("record[1] must be a non-empty string event")
            # A raw line break inside the time period would split the statement.
            lines.append(
                f" {cls._format_time_text(t)} : {cls._format_event_text(event)}"
            )
        return "\n".join(lines)