import requests
import shutil
import os
import hashlib
import platform
from typing import Optional, Dict, Tuple
from rich.progress import Progress, BarColumn, DownloadColumn, TextColumn, TimeRemainingColumn, TransferSpeedColumn
class DownloadError(Exception):
    """Raised when a download, checksum verification, or archive step fails."""
def detect_platform() -> Dict[str, str]:
    """Probe the local machine and return its ``osType`` and ``osArch`` values.

    The OS name is normalised to ``windows``, ``linux`` or ``darwin`` and the
    CPU name to ``x64`` or ``arm64``; anything unrecognised is passed through
    as the lower-cased value reported by :mod:`platform`.
    """
    system_name = platform.system().lower()
    cpu_name = platform.machine().lower()

    # Checked in order; the first matching prefix group wins.
    os_prefixes = (
        ('windows', ('windows',)),
        ('linux', ('linux',)),
        ('darwin', ('darwin', 'mac')),
    )
    os_type = system_name
    for canonical, prefixes in os_prefixes:
        if system_name.startswith(prefixes):
            os_type = canonical
            break

    arch_aliases = {
        'amd64': 'x64',
        'x86_64': 'x64',
        'x64': 'x64',
        'arm64': 'arm64',
        'aarch64': 'arm64',
    }
    os_arch = arch_aliases.get(cpu_name, cpu_name)

    return {'osType': os_type, 'osArch': os_arch}
def checksum(file_path: str, expected_checksum: Optional[Tuple[str, str]]) -> bool:
    """Hash a file and compare it with an expected ``(algorithm, digest)`` pair.

    Args:
        file_path: Path of the file to hash.
        expected_checksum: ``(algorithm, hex_digest)``. The algorithm must be
            ``sha256``, ``sha1`` or ``md5`` (case-insensitive). ``None`` or an
            empty value means there is nothing to verify. If the digest part
            is empty the file is still hashed (so it must be readable) but no
            comparison is made.

    Returns:
        ``True`` when the digest matches or there is nothing to compare
        against. A mismatch is reported by raising, never by returning
        ``False``.

    Raises:
        ValueError: The algorithm is not one of the supported names.
        DownloadError: The file could not be read, or the digest differs.
    """
    # Honour the documented "no expected checksum" contract instead of
    # failing with a TypeError on ``None[0]``.
    if not expected_checksum:
        return True

    algo = expected_checksum[0].lower()
    if algo not in ('sha256', 'sha1', 'md5'):
        raise ValueError('Unsupported checksum algorithm: ' + algo)

    h = hashlib.new(algo)
    try:
        with open(file_path, 'rb') as f:
            # Stream in fixed-size chunks so large files are never fully
            # loaded into memory.
            for chunk in iter(lambda: f.read(8192), b''):
                h.update(chunk)
    except Exception as e:
        # Keep the original cause attached for debugging.
        raise DownloadError('Failed to compute checksum for {}: {}'.format(file_path, e)) from e

    computed = h.hexdigest()
    checksum_value = expected_checksum[1]
    # Hex digests are compared case-insensitively.
    if checksum_value and computed.lower() != checksum_value.lower():
        raise DownloadError('Checksum mismatch: expected {}, got {}'.format(checksum_value, computed))
    return True
def _discard_partial(path: str) -> None:
    """Remove a leftover partial download, ignoring a missing or locked file."""
    try:
        os.remove(path)
    except (FileNotFoundError, PermissionError):
        # Best-effort cleanup: the original error is what the caller needs.
        pass


def download_component(url: str, dest_path: str, expected_checksum: Optional[Tuple[str, str]] = None, chunk_size: int = 8192) -> str:
    """Download one component to ``dest_path``, optionally verifying a checksum.

    Args:
        url: Address to fetch with an HTTP GET (30 second timeout).
        dest_path: Final location of the downloaded file. Missing parent
            directories are created.
        expected_checksum: Optional ``(algorithm, hex_digest)`` pair passed to
            ``checksum``. When omitted, an already existing ``dest_path`` is
            returned as-is without downloading.
        chunk_size: Number of bytes requested per streamed chunk.

    Returns:
        ``dest_path``.

    Raises:
        DownloadError: The HTTP request failed, or the downloaded data does
            not match ``expected_checksum``.
    """
    if os.path.exists(dest_path):
        if not expected_checksum:
            return dest_path
        # ``checksum`` signals a mismatch by raising, so a stale file has to
        # be caught here or the re-download below would never run.
        try:
            matches = checksum(dest_path, expected_checksum=expected_checksum)
        except DownloadError:
            matches = False
        if matches:
            print('Info: 文件 {} 已存在存在,校验和匹配'.format(dest_path))
            return dest_path
        print('Warning: existing file {} checksum mismatch, re-downloading'.format(dest_path))

    os.makedirs(os.path.dirname(os.path.abspath(dest_path)) or '.', exist_ok=True)
    tmp_path = dest_path + '.part'
    try:
        # Both the session and the response are closed on every exit path.
        with requests.Session() as session, session.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()

            content_length = r.headers.get('Content-Length')
            try:
                total_size = int(content_length) if content_length else None
            except (ValueError, TypeError):
                total_size = None

            # The progress bar is optional: if it cannot be set up, download
            # silently, and never leave a half-started display behind.
            rich_progress = None
            task_id = None
            try:
                rich_progress = Progress(
                    TextColumn('{task.fields[filename]}', justify='right'),
                    BarColumn(),
                    DownloadColumn(),
                    TransferSpeedColumn(),
                    TimeRemainingColumn(),
                )
                rich_progress.start()
                task_id = rich_progress.add_task('download', filename=os.path.basename(dest_path), total=total_size or 0)
            except (ImportError, TypeError):
                if rich_progress is not None:
                    rich_progress.stop()
                rich_progress = None

            try:
                with open(tmp_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=chunk_size):
                        if not chunk:
                            continue
                        f.write(chunk)
                        if rich_progress is not None:
                            try:
                                rich_progress.update(task_id, advance=len(chunk))
                            except (KeyError, AttributeError):
                                # A display glitch must not abort the download.
                                pass
            finally:
                if rich_progress is not None:
                    rich_progress.stop()

        # Verify the temporary file BEFORE publishing it, so a corrupt
        # download never ends up at dest_path.
        if expected_checksum and not checksum(tmp_path, expected_checksum=expected_checksum):
            raise DownloadError('Checksum mismatch: expected {}'.format(expected_checksum[1]))
        shutil.move(tmp_path, dest_path)
        return dest_path
    except requests.RequestException as e:
        _discard_partial(tmp_path)
        raise DownloadError('Failed to download {}: {}'.format(url, e)) from e
    except Exception:
        _discard_partial(tmp_path)
        raise
def extract_archive(archive_path: str, dest_dir: str, overwrite: bool = True) -> str:
    """Extract a zip, tar, or 7z archive into ``dest_dir``.

    Supported suffixes (case-insensitive): ``.zip``, ``.tar``, ``.tar.gz``,
    ``.tgz``, ``.tar.bz2``, ``.tbz2``, ``.tar.xz``, ``.txz`` and ``.7z``.

    Args:
        archive_path: Path of the archive to unpack.
        dest_dir: Directory that receives the contents. Unless extraction is
            skipped (see ``overwrite``), an existing directory is deleted and
            recreated first.
        overwrite: When ``False`` and ``dest_dir`` already exists and is not
            empty, nothing is extracted and ``dest_dir`` is returned as-is.

    Returns:
        ``dest_dir``.

    Raises:
        DownloadError: The archive does not exist, its suffix is not
            supported, or a tar member would be written outside ``dest_dir``.
    """
    import stat
    import tarfile

    tar_suffixes = ('.tar', '.tar.gz', '.tgz', '.tar.bz2', '.tbz2', '.tar.xz', '.txz')

    def on_rm_error(func, path, exc_info):
        """Retry a failed removal after clearing the read-only bit."""
        try:
            os.chmod(path, stat.S_IWRITE)
            func(path)
        except (PermissionError, FileNotFoundError):
            print('Warning: unable to remove {} while clearing {}: {}'.format(path, dest_dir, exc_info))

    if not os.path.exists(archive_path):
        raise DownloadError('未找到存档: {}'.format(archive_path))

    if not overwrite and os.path.exists(dest_dir) and os.listdir(dest_dir):
        print('Info: destination directory {} already exists and is not empty, skipping extraction'.format(dest_dir))
        return dest_dir

    lower = archive_path.lower()
    is_zip = lower.endswith('.zip')
    is_tar = lower.endswith(tar_suffixes)
    is_7z = lower.endswith('.7z')
    # Reject unknown formats BEFORE wiping dest_dir, so a bad call does not
    # destroy an existing installation.
    if not (is_zip or is_tar or is_7z):
        raise DownloadError('Unsupported archive format: {}'.format(archive_path))

    if os.path.exists(dest_dir):
        shutil.rmtree(dest_dir, onerror=on_rm_error)
    os.makedirs(dest_dir, exist_ok=True)

    if is_zip:
        from .ziptools import extractzipfile
        extractzipfile(archive_path, dest_dir, trace=None, permissions=True)
        print('解压完成: {} -> {}'.format(archive_path, dest_dir))
    elif is_tar:
        # 'r:*' lets tarfile detect the compression itself.
        with tarfile.open(archive_path, 'r:*') as t:
            if hasattr(tarfile, 'data_filter'):
                # Extraction filters refuse members that escape dest_dir
                # (absolute paths, '..', outward-pointing links).
                try:
                    t.extractall(dest_dir, filter='data')
                except tarfile.FilterError as e:
                    raise DownloadError('Unsafe member in archive {}: {}'.format(archive_path, e)) from e
            else:
                # Older interpreters: check every member path manually.
                root = os.path.realpath(dest_dir)
                root_prefix = root.rstrip(os.sep) + os.sep
                for member in t.getmembers():
                    target = os.path.realpath(os.path.join(root, member.name))
                    if target != root and not target.startswith(root_prefix):
                        raise DownloadError('Unsafe member in archive {}: {}'.format(archive_path, member.name))
                t.extractall(dest_dir)
    else:
        # Imported lazily so zip/tar extraction works without py7zr installed.
        import py7zr
        with py7zr.SevenZipFile(archive_path, mode='r') as z:
            z.extractall(path=dest_dir)
    return dest_dir
def create_archive(source_dir: str, archive_path: str, _format: str = 'zip') -> str:
    """Create a zip or gzip-compressed tar archive from ``source_dir``.

    Args:
        source_dir: Directory whose contents become the archive root.
        archive_path: Desired archive location. A trailing suffix matching the
            format (``.zip``, or ``.tar.gz`` / ``.tgz``) is stripped before
            the name is handed to ``shutil.make_archive``, which appends its
            own suffix (``.zip`` or ``.tar.gz``). Any other name is used
            unchanged as the base name.
        _format: ``'zip'``, ``'tar.gz'`` or ``'tgz'`` (case-insensitive).

    Returns:
        The path reported by ``shutil.make_archive``. This can differ from
        ``archive_path``: a ``.tgz`` request produces a ``.tar.gz`` file.

    Raises:
        ValueError: ``source_dir`` is not a directory, or ``_format`` is not
            supported.
    """
    if not os.path.isdir(source_dir):
        raise ValueError('Source directory does not exist: {}'.format(source_dir))

    _format = _format.lower()
    if _format not in ('zip', 'tar.gz', 'tgz'):
        raise ValueError('Unsupported archive format: {}'.format(_format))

    # Strip only a suffix that belongs to the chosen format. Generic
    # os.path.splitext would also eat version components such as the ".2"
    # in "release-1.2.tgz".
    known_suffixes = ('.zip',) if _format == 'zip' else ('.tar.gz', '.tgz')
    lowered = archive_path.lower()
    base_name = archive_path
    for suffix in known_suffixes:
        if lowered.endswith(suffix):
            base_name = archive_path[:-len(suffix)]
            break

    print('Creating archive {} in {}'.format(base_name, source_dir))
    archive_full_path = shutil.make_archive(base_name, 'zip' if _format == 'zip' else 'gztar', root_dir=source_dir)
    print('Created archive: {} from {}'.format(archive_full_path, source_dir))
    return archive_full_path