import ipaddress
import logging
import os
import re
import socket
from typing import Any
from urllib.parse import urlparse, urlunparse
from openjiuwen_deepsearch.common.exception import CustomValueException
from openjiuwen_deepsearch.common.status_code import StatusCode
from openjiuwen_deepsearch.common.common_constants import MAX_URL_LENGTH
# Module-level logger; the scheme validators in this module use it to report
# URLs they block.
logger = logging.getLogger(__name__)
# Allow-list of URL schemes; the scheme validators compare the lower-cased
# scheme of a URL against this set.
SAFE_URL_SCHEMES = frozenset(['http', 'https'])
def validate_and_sanitize_url(url: str) -> str:
    """
    Validate a URL and let it through only when its scheme is allow-listed.

    Args:
        url: The URL to validate.

    Returns:
        str: The URL with surrounding whitespace removed when its scheme is
        in ``SAFE_URL_SCHEMES``; an empty string for empty input, for input
        that does not start with ``<scheme>://``, or for a disallowed scheme.
    """
    if not url:
        return ""

    url_stripped = url.strip()
    leading_scheme = re.match(r'^([a-zA-Z][a-zA-Z0-9+.-]*)://', url_stripped)

    # Guard clause: no "<scheme>://" prefix at all.
    if leading_scheme is None:
        logger.warning(
            f"URL without valid scheme blocked: {url_stripped[:100]}"
        )
        return ""

    # Scheme comparison is case-insensitive.
    scheme = leading_scheme.group(1).lower()
    if scheme in SAFE_URL_SCHEMES:
        return url_stripped

    # Only the first 100 characters of the URL are written to the log.
    logger.warning(
        f"URL scheme '{scheme}' is not allowed, URL blocked: "
        f"{url_stripped[:100]}"
    )
    return ""
def validate_url_scheme(url: str) -> tuple:
    """
    Check whether a URL's scheme is in the allow-list (http/https).

    Args:
        url: The URL to validate.

    Returns:
        tuple: ``(safe_url, is_valid)``
            - safe_url: the whitespace-stripped URL, or an empty string when
              the URL is empty, has no ``<scheme>://`` prefix, or uses a
              scheme that is not allowed.
            - is_valid: True when the scheme is allowed, False otherwise.
    """
    rejected = ("", False)
    if not url:
        return rejected

    url_stripped = url.strip()
    leading_scheme = re.match(r'^([a-zA-Z][a-zA-Z0-9+.-]*)://', url_stripped)

    # Guard clause: no "<scheme>://" prefix at all.
    if leading_scheme is None:
        logger.warning(
            f"URL without valid scheme blocked: {url_stripped[:100]}"
        )
        return rejected

    # Scheme comparison is case-insensitive.
    scheme = leading_scheme.group(1).lower()
    if scheme in SAFE_URL_SCHEMES:
        return url_stripped, True

    # Only the first 100 characters of the URL are written to the log.
    logger.warning(
        f"URL scheme '{scheme}' is not allowed, "
        f"URL blocked: {url_stripped[:100]}"
    )
    return rejected
def normalize_domain(domain: str) -> str:
    """
    Normalize a domain by trimming common hallucinated suffixes.

    Args:
        domain: The original domain.

    Returns:
        The domain after every rewrite rule has been applied in order. Rules
        only look at the end of the string and only match lower-case letters.
    """
    # (regex, replacement) pairs; each rule runs on the previous rule's output.
    rewrite_rules = (
        # Well-known TLDs followed by a stray "-word", e.g. ".com-news" -> ".com".
        (r'\.com-([a-z]+)$', r'.com'),
        (r'\.net-([a-z]+)$', r'.net'),
        (r'\.org-([a-z]+)$', r'.org'),
        (r'\.edu-([a-z]+)$', r'.edu'),
        (r'\.gov-([a-z]+)$', r'.gov'),
        # Any other final label followed by "-word": keep the label only.
        (r'\.([a-z]+)-([a-z]+)$', r'.\1'),
        # A remaining trailing "-word" is dropped entirely.
        (r'-([a-z]+)$', r''),
    )

    result = domain
    for rule in rewrite_rules:
        regex, substitute = rule
        result = re.sub(regex, substitute, result)
    return result
def normalize_domains(domains: Any) -> list[str]:
    """Normalize a collection of domains into a de-duplicated list of hosts.

    Each entry is stringified, stripped and lower-cased, then reduced to its
    host part: any scheme, path, ``user@`` credentials and ``:port`` are
    removed, leading/trailing dots are trimmed and one leading ``www.`` is
    dropped.

    Args:
        domains: A single domain/URL string, or a list, tuple or set of them.
            Any other type (and any falsy value) yields an empty list.

    Returns:
        The normalized hosts in first-seen order, without duplicates. Entries
        that are ``None``, that normalize to an empty string, or that
        ``urlparse`` rejects as malformed are skipped.
    """
    if not domains:
        return []
    if isinstance(domains, str):
        domains = [domains]
    if not isinstance(domains, (list, tuple, set)):
        return []

    normalized_domains: list[str] = []
    seen: set[str] = set()
    for domain in domains:
        # str(None) would otherwise produce the bogus domain "none".
        if domain is None:
            continue
        domain_str = str(domain).strip().lower()
        try:
            # Prefix "//" so urlparse treats a bare host as the netloc.
            parsed = urlparse(domain_str if "://" in domain_str else f"//{domain_str}")
        except ValueError:
            # urlparse raises on unmatched/invalid "[...]" hosts; one bad
            # entry must not abort normalization of the remaining ones.
            continue
        domain_str = parsed.netloc or parsed.path
        # Drop "user:pass@" credentials, then ":port", then stray dots.
        domain_str = domain_str.split("@")[-1].split(":")[0].strip(".")
        if domain_str.startswith("www."):
            domain_str = domain_str[4:]
        if not domain_str or domain_str in seen:
            continue
        seen.add(domain_str)
        normalized_domains.append(domain_str)
    return normalized_domains
def extract_domain_from_url(url: Any) -> str:
    """Extract the lower-cased host from a URL.

    Args:
        url: A URL or bare host; any value is stringified, and falsy values
            are treated as an empty string.

    Returns:
        The host without ``user@`` credentials, ``:port``, leading/trailing
        dots or a leading ``www.``. Returns an empty string when the input is
        empty, has no network location, or is rejected by ``urlparse`` as
        malformed.
    """
    url_str = str(url or "").strip().lower()
    if not url_str:
        return ""
    try:
        # Prefix "//" so urlparse treats a scheme-less value as a netloc.
        parsed = urlparse(url_str if "://" in url_str else f"//{url_str}")
    except ValueError:
        # urlparse raises on unmatched/invalid "[...]" hosts; report
        # "no domain" instead of propagating the parse error.
        return ""
    # Drop "user:pass@" credentials, then ":port", then stray dots.
    domain = (parsed.netloc or "").split("@")[-1].split(":")[0].strip(".")
    if domain.startswith("www."):
        domain = domain[4:]
    return domain
def normalize_path(path: str) -> str:
    """
    Normalize a URL path.

    Args:
        path: The original path.

    Returns:
        The path with every run of consecutive slashes collapsed to a single
        slash and a leading slash guaranteed (an empty path becomes "/").

    Raises:
        CustomValueException: If ``path`` is longer than ``MAX_URL_LENGTH``.
    """
    # Reject oversized input before running the regex over it.
    if len(path) > MAX_URL_LENGTH:
        too_long = StatusCode.PARAM_CHECK_ERROR_URL_EXCEED_LENGTH
        raise CustomValueException(
            error_code=too_long.code,
            message=too_long.errmsg,
        )

    collapsed = re.sub(r'/+', '/', path)
    return collapsed if collapsed.startswith('/') else '/' + collapsed
def fix_domain_path_merge(url: str) -> str:
    """
    Repair URLs whose first path segment was glued onto the domain with "-".

    A URL of the form ``scheme://<domain>-<word>/<rest>`` (``<word>`` being
    lower-case letters only) is rewritten to
    ``scheme://<domain>/<word>/<rest>``. The original scheme (http or https)
    is preserved; URLs that do not match are returned unchanged.

    Args:
        url: The original URL.

    Returns:
        The repaired URL, or ``url`` itself when the pattern does not match.

    Raises:
        CustomValueException: If ``url`` is longer than ``MAX_URL_LENGTH``.
    """
    # Reject oversized input before running the regex over it.
    if len(url) > MAX_URL_LENGTH:
        raise CustomValueException(
            error_code=StatusCode.PARAM_CHECK_ERROR_URL_EXCEED_LENGTH.code,
            message=StatusCode.PARAM_CHECK_ERROR_URL_EXCEED_LENGTH.errmsg)

    # Capture the scheme so an http URL is not silently turned into https.
    match = re.match(r'(https?)://([^/]+)-([a-z]+)/(.+)', url)
    if match is None:
        return url

    scheme, domain, path_prefix, rest_path = match.groups()
    return f'{scheme}://{domain}/{path_prefix}/{rest_path}'
def normalize_url(url: str) -> str:
    """
    Normalize a URL, repairing typical hallucination artifacts.

    The URL is first passed through ``fix_domain_path_merge``, then parsed;
    its network location goes through ``normalize_domain`` and its path
    through ``normalize_path``. A missing scheme defaults to ``https``.

    Args:
        url: The original URL.

    Returns:
        The normalized URL. This is best-effort: if any step raises, the
        input ``url`` is returned unchanged.
    """
    try:
        parts = urlparse(fix_domain_path_merge(url))
        host = normalize_domain(parts.netloc)
        path = normalize_path(parts.path)
        scheme = parts.scheme or 'https'
        rebuilt = (scheme, host, path, parts.params, parts.query, parts.fragment)
        return urlunparse(rebuilt)
    except Exception:
        # Deliberate catch-all: normalization must never fail the caller.
        return url
def are_similar_urls(url1: str, url2: str, threshold: float = 0.9) -> bool:
    """
    Decide whether two URLs are similar (likely the same page written differently).

    Both URLs are normalized with ``normalize_url``. They are similar when the
    normalized forms are equal, when their network location and path are both
    equal (query and fragment are not compared), or when the
    ``difflib.SequenceMatcher`` ratio of the normalized strings reaches
    ``threshold``.

    Args:
        url1: The first URL.
        url2: The second URL.
        threshold: Minimum similarity ratio for the fuzzy comparison.

    Returns:
        True when the URLs are considered similar; False otherwise, including
        when any step raises.
    """
    try:
        first = normalize_url(url1)
        second = normalize_url(url2)
        if first == second:
            return True

        first_parts = urlparse(first)
        second_parts = urlparse(second)
        same_location = (
            (first_parts.netloc, first_parts.path)
            == (second_parts.netloc, second_parts.path)
        )
        if same_location:
            return True

        # Imported here, as in the original, so it is only loaded when the
        # fuzzy comparison is actually needed.
        from difflib import SequenceMatcher
        return SequenceMatcher(None, first, second).ratio() >= threshold
    except Exception:
        # Deliberate catch-all: a failed comparison counts as "not similar".
        return False
def _http_service_allow_unsafe_url(env_var: str) -> bool:
    """Return True when environment variable ``env_var`` is set to an enabling value.

    The value is stripped and lower-cased; "1", "true" and "yes" enable the
    flag. An unset variable or any other value returns False.
    """
    raw_flag = os.environ.get(env_var, "")
    normalized_flag = raw_flag.strip().lower()
    return normalized_flag in ("1", "true", "yes")
def _is_runtime_api_unsafe_url_relaxed() -> bool:
    """Return True when RUNTIME_API_ALLOW_UNSAFE_URL is set to an enabling value."""
    env_var = "RUNTIME_API_ALLOW_UNSAFE_URL"
    return _http_service_allow_unsafe_url(env_var)
def _unsafe_http_service_url_exception_detail(
    url: str, reason: str, service_label: str
) -> str:
    """Format the detail text for a rejected URL.

    The result names the service label, the rejection reason in parentheses,
    and the ``repr`` of the offending URL.
    """
    detail = f"{service_label} is not allowed ({reason}): {url!r}"
    return detail
def _unsafe_url_error(url: str, reason: str, service_label: str) -> Exception:
    """Build (without raising) the parameter-check exception for a rejected URL."""
    status = StatusCode.PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR
    return CustomValueException(
        status.code,
        status.errmsg.format(
            e=_unsafe_http_service_url_exception_detail(url, reason, service_label)
        ),
    )


def _is_non_public_ip(ip) -> bool:
    """Return True when ``ip`` (an ``ipaddress`` address object) is not publicly routable."""
    return any((
        ip.is_private,
        ip.is_loopback,
        ip.is_link_local,
        ip.is_multicast,
        ip.is_reserved,
        ip.is_unspecified,
        # Shared address space (100.64.0.0/10) is neither "private" nor
        # global, so the flags above alone would let it through.
        not ip.is_global,
    ))


def _resolve_host_addresses(url: str, parsed, host: str, service_label: str) -> set:
    """Resolve ``host`` with ``socket.getaddrinfo`` and return the distinct address strings."""
    try:
        # parsed.port raises ValueError for an out-of-range/invalid port.
        port = parsed.port or (443 if parsed.scheme == "https" else 80)
        addr_info = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
    except (ValueError, socket.gaierror) as error:
        raise _unsafe_url_error(
            url, f"host resolution failed: {error}", service_label
        ) from error
    # entry[4] is the sockaddr tuple; its first item is the address string.
    addresses = {entry[4][0] for entry in addr_info if entry[4]}
    if not addresses:
        raise _unsafe_url_error(
            url, "host resolution returned no addresses", service_label
        )
    return addresses


def _validate_http_url_for_ssrf(
    url: str, *, relaxed: bool, service_label: str
) -> None:
    """Reject URLs that could make a server-side request reach a non-public host.

    Checks, in order: the URL parses; the scheme is http or https; a host is
    present; the host is not ``localhost`` / ``*.localhost``; and the host —
    either an IP literal or every address it resolves to — is a public
    address.

    NOTE(review): the DNS lookup done here is separate from whatever lookup
    the later HTTP request performs, so this check alone does not defend
    against a hostname whose resolution changes between the two.

    Args:
        url: The URL to validate.
        relaxed: When True, skip every check and accept the URL.
        service_label: Human-readable name of the setting being validated;
            it is embedded in the error detail.

    Raises:
        CustomValueException: With the PARAM_CHECK_ERROR_REQUEST_PARAM_ERROR
            status when any check fails, including a URL that ``urlparse``
            cannot parse.
    """
    if relaxed:
        return

    def reject(reason: str) -> Exception:
        return _unsafe_url_error(url, reason, service_label)

    try:
        parsed = urlparse(url)
    except ValueError as error:
        # e.g. unmatched "[" in the host; surface it like every other
        # rejection instead of leaking a raw ValueError to callers.
        raise reject(f"malformed URL: {error}") from error

    if parsed.scheme not in ("http", "https"):
        raise reject("scheme must be http or https")

    host = parsed.hostname
    if not host:
        raise reject("missing host")

    host_lower = host.lower()
    if host_lower == "localhost" or host_lower.endswith(".localhost"):
        raise reject("localhost host is blocked")

    try:
        literal_ip = ipaddress.ip_address(host)
    except ValueError as host_parse_error:
        # Not an IP literal: resolve the name and vet every address it maps to.
        for address in _resolve_host_addresses(url, parsed, host, service_label):
            try:
                resolved_ip = ipaddress.ip_address(address)
            except ValueError as error:
                raise reject(f"invalid resolved address: {address}") from error
            if _is_non_public_ip(resolved_ip):
                raise reject("private or non-public IP") from host_parse_error
        return

    if _is_non_public_ip(literal_ip):
        raise reject("private or non-public IP")
def validate_runtime_request_url(url: str) -> None:
    """
    Validate a runtime API request URL to reduce SSRF risk.

    Delegates to ``_validate_http_url_for_ssrf`` with the label
    "runtime api url". For local debugging the check can be bypassed by
    setting the RUNTIME_API_ALLOW_UNSAFE_URL environment variable.
    """
    bypass_checks = _is_runtime_api_unsafe_url_relaxed()
    _validate_http_url_for_ssrf(
        url, relaxed=bypass_checks, service_label="runtime api url"
    )
def validate_embedding_service_url(url: str) -> None:
    """
    Validate an embedding HTTP service base URL to reduce SSRF risk.

    Delegates to ``_validate_http_url_for_ssrf`` with the label
    "embedding service url". For local debugging the check can be bypassed
    with EMBEDDING_SERVICE_ALLOW_UNSAFE_URL=1 (same accepted values as
    RUNTIME_API_ALLOW_UNSAFE_URL).
    """
    bypass_checks = _http_service_allow_unsafe_url(
        "EMBEDDING_SERVICE_ALLOW_UNSAFE_URL"
    )
    _validate_http_url_for_ssrf(
        url, relaxed=bypass_checks, service_label="embedding service url"
    )