from download_util import (
check_sha256,
check_sha256_by_mark,
extract_compress_files_and_gen_mark,
get_local_path,
run_cmd,
import_rich_module,
)
import os
import sys
import re
import glob
import traceback
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
import requests
from dfx.build_tracker import build_tracker
class PoolDownloader:
def __init__(self, download_configs: list, global_args: object = None):
if not global_args.disable_rich:
self.progress = import_rich_module()
else:
self.progress = None
self.global_args = global_args
self.download_configs = download_configs
self.lock = threading.Lock()
self.unchanged_tool_list = []
@build_tracker(
event_name="_hwcloud_download",
build_type="prebuild"
)
def start(self) -> list:
if self.progress:
with self.progress:
self._run_download_in_thread_pool()
else:
self._run_download_in_thread_pool()
return self.unchanged_tool_list
def _run_download_in_thread_pool(self):
try:
cnt = cpu_count()
except Exception as e:
cnt = 1
with ThreadPoolExecutor(max_workers=cnt) as pool:
tasks = dict()
for config_item in self.download_configs:
task = pool.submit(self._process, config_item)
tasks[task] = os.path.basename(config_item.get("remote_url"))
self._wait_for_download_tasks_complete(tasks)
def _wait_for_download_tasks_complete(self, tasks: dict):
for task in as_completed(tasks):
try:
_ = task.result()
except Exception as e:
self._adaptive_print(f"Task {task} generated an exception: {e}", style="red")
self._adaptive_print(traceback.format_exc())
else:
self._adaptive_print(
"{}, download and decompress completed".format(tasks.get(task)),
style="green",
)
def _adaptive_print(self, msg: str, **kwargs):
if self.progress:
self.progress.console.log(msg, **kwargs)
else:
print(msg)
def _process(self, operate: dict):
global_args = self.global_args
remote_url = operate.get("remote_url")
if "python" in remote_url and global_args.glibc_version is not None:
remote_url = re.sub(r"GLIBC[0-9]\.[0-9]{2}", global_args.glibc_version, remote_url)
remote_url = global_args.tool_repo + remote_url
download_root = operate.get("download_dir")
unzip_dir = operate.get("unzip_dir")
unzip_filename = operate.get("unzip_filename")
local_path = get_local_path(download_root, remote_url)
self._adaptive_print(f"start deal {remote_url}")
mark_file_exist, mark_file_path = check_sha256_by_mark(remote_url, unzip_dir, unzip_filename)
if mark_file_exist:
self._adaptive_print(
"{}, Sha256 markword check OK.".format(remote_url), style="green"
)
with self.lock:
self.unchanged_tool_list.append(operate.get("name") + "_" + os.path.basename(remote_url))
else:
if "merged_download_config" in operate:
for item in operate["merged_download_config"]:
item_unzip_dir = item["unzip_dir"]
item_unzip_filename = item["unzip_filename"]
run_cmd(["rm", "-rf", '{}/{}'.format(item_unzip_dir, item_unzip_filename)])
else:
run_cmd(["rm", "-rf", '{}/{}'.format(unzip_dir, unzip_filename)])
run_cmd(["rm", "-rf"] + glob.glob(f"{unzip_dir}/*.{unzip_filename}.mark", recursive=False))
if os.path.exists(local_path):
check_result = check_sha256(remote_url, local_path)
if check_result:
self._adaptive_print(
"{}, Sha256 check download OK.".format(local_path),
style="green",
)
else:
os.remove(local_path)
self._try_download(remote_url, local_path)
else:
self._try_download(remote_url, local_path)
self._adaptive_print("Start decompression {}".format(local_path))
if "merged_download_config" in operate:
for item in operate["merged_download_config"]:
item_unzip_dir = item["unzip_dir"]
extract_compress_files_and_gen_mark(local_path, item_unzip_dir, mark_file_path)
else:
extract_compress_files_and_gen_mark(local_path, unzip_dir, mark_file_path)
self._adaptive_print(f"{local_path} extracted to {unzip_dir}")
def _try_download(self, remote_url: str, local_path: str):
max_retry_times = 3
download_dir = os.path.dirname(local_path)
os.makedirs(download_dir, exist_ok=True)
progress = self.progress
progress_task_id = progress.add_task(
"download", filename=os.path.basename(remote_url), start=False
) if progress else None
self._adaptive_print(f"Downloading {remote_url}")
for retry_times in range(max_retry_times):
try:
self._download_remote_file(remote_url, local_path, progress_task_id)
return
except Exception as e:
error_message = getattr(e, 'code', str(e))
self._adaptive_print(
f"Failed to open {remote_url}, Error: {error_message}",
style="red"
)
self._adaptive_print(
f"{local_path}, download failed after {max_retry_times} retries, "
"please check network status. Prebuilts download exit."
)
sys.exit(1)
def _download_remote_file(self, remote_url: str, local_path: str, progress_task_id):
buffer_size = 32768
progress = self.progress
if self.global_args.skip_ssl:
ssl_verify = False
else:
ssl_verify = True
with requests.get(remote_url, stream=True, timeout=(30, 600), verify=ssl_verify) as response:
response.raise_for_status()
total_size = int(response.headers.get("Content-Length", 0))
if progress:
progress.update(progress_task_id, total=total_size)
progress.start_task(progress_task_id)
self._save_to_local(response, local_path, buffer_size, progress_task_id)
self._adaptive_print(f"Downloaded {local_path}")
def _save_to_local(self, response: requests.Response, local_path: str, buffer_size: int, progress_task_id):
with os.fdopen(os.open(local_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, mode=0o640), 'wb') as dest_file:
for chunk in response.iter_content(chunk_size=buffer_size):
if chunk:
dest_file.write(chunk)
self._update_progress(progress_task_id, len(chunk))
def _update_progress(self, task_id, advance):
if self.progress:
self.progress.update(task_id, advance=advance)