import os
import random
import socket
import sys
import threading
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import List, Dict
import multiprocessing
from urllib import request
from urllib.parse import urlparse
from . import logger_config
from .download_util import Color, DownloadError, DownloadUtil, is_exists, delete_if_exist, \
DownloadCheckError, get_remote_content_length, DOWNLOAD_INST, UrlOpenError, UrlOpenErrInfo
from ascend_deployer.module_utils.safe_file_handler import SafeFileHandler
# Module logger, taken from the package's logger_config module.
LOG = logger_config.LOG
class DownloadingStatus:
    """Progress snapshot of a single package download, plus text renderers for it.

    Attributes:
        pkg_name: display name of the package.
        speed: current download speed in KB/s (rendered as MB/s from 1024 KB/s up).
        percent: completed fraction; 1 (or more) means the download is finished.
        total_size: total size reported for the package.
    """

    def __init__(self, pkg_name: str, speed: float, percent: float, total_size: float):
        self.pkg_name, self.speed = pkg_name, speed
        self.percent, self.total_size = percent, total_size

    def get_speed_str(self):
        """Return the speed with two decimals, in KB/s below 1024 KB/s and in MB/s otherwise."""
        return (r" {:.2f} KB/s".format(self.speed) if self.speed < 1024
                else r" {:.2f} MB/s".format(self.speed / 1024))

    def get_percent_str(self):
        """Return the completed fraction as a percentage string such as ``"42.00%"``."""
        return f"{self.percent * 100:.2f}%"

    def get_pkg_str(self):
        """Return the package name, truncated to 47 characters plus ``"..."`` if longer than 50."""
        name = self.pkg_name
        return name if len(name) <= 50 else name[:47] + "..."

    def get_progress_str(self):
        """Return a 30-column bar: all ``=`` when complete, else ``=`` cells, a ``>`` head and ``-`` padding."""
        filled = round(self.percent * 30)
        if self.percent != 1:
            return ('=' * (filled - 1) + '>').ljust(30, '-')
        return '=' * filled

    def get_downloading_print_str(self):
        """Assemble the full single-line progress row for this package."""
        parts = (
            Color.info("start downloading "),
            self.get_pkg_str().ljust(53, ' '),
            ' ',
            self.get_percent_str().ljust(7, ' '),
            '[',
            self.get_progress_str(),
            ']',
            self.get_speed_str().ljust(20),
        )
        return "".join(parts)

    def is_finished(self):
        """Return True once the completed fraction has reached 1."""
        return self.percent >= 1

    def is_downloading(self):
        """Return True while the download has started but not yet completed."""
        return self.percent > 0 and self.percent < 1
class DownloadingBlockInfo:
    """The three progress values handed to a download progress callback.

    Attributes:
        block_num: number of blocks transferred so far.
        block_size: size of one block.
        total_size: total size reported for the file.
    """

    def __init__(self, block_num, block_size, total_size):
        self.block_num, self.block_size, self.total_size = block_num, block_size, total_size
class DownloadingStatusBuilder:
    """Convert raw progress callbacks for one package into ``DownloadingStatus`` snapshots."""

    def __init__(self, pkg_name):
        self.pkg_name = pkg_name
        # Wall-clock creation time; the average speed is measured from this point.
        self.start_time = time.time()

    def get_speed(self, block_info: 'DownloadingBlockInfo'):
        """Return the average speed since creation, in KB/s."""
        speed = self._get_normal_speed(block_info.block_num, block_info.block_size)
        return float(speed) / 1024

    def _get_normal_speed(self, block_num, block_size):
        """Return ``block_num * block_size`` divided by the elapsed seconds, or 0.

        ``time.time()`` is a wall clock and can step backwards. A non-positive
        elapsed time is therefore treated as "no time elapsed" instead of
        producing a negative (or infinite) speed.
        """
        used_time = time.time() - self.start_time
        if used_time <= 0:
            return 0
        return block_num * block_size / used_time

    @staticmethod
    def get_percent(block_info: 'DownloadingBlockInfo'):
        """Return the completed fraction of the download, capped at 1.

        ``urllib.request.urlretrieve`` passes a total size of -1 when the server
        sends no Content-Length. Any non-positive total is treated as unknown
        and yields 0 rather than a negative fraction.
        """
        if block_info.total_size <= 0:
            return 0
        percent = block_info.block_num * block_info.block_size / block_info.total_size
        return min(percent, 1)

    def build_downloading_status(self, block_info: 'DownloadingBlockInfo'):
        """Return a ``DownloadingStatus`` for this package from one progress callback."""
        speed = self.get_speed(block_info)
        percent = self.get_percent(block_info)
        return DownloadingStatus(self.pkg_name, speed, percent, block_info.total_size)
class DownloadFileInfo:
    """Description of one file to fetch: source URL, destination path and expected digest.

    A ``sha256=<hex>`` or ``md5=<hex>`` parameter embedded in ``url`` (for example a
    ``#sha256=...`` fragment) overrides the matching keyword argument. ``sha256``
    is preferred when both appear.
    """

    def __init__(self, filename="", url="", sha256="", md5="", dst_file_path="", content_length=0):
        self.filename = filename
        self.url = url
        self.sha256 = sha256
        self.md5 = md5
        self.dst_file_path = dst_file_path
        # Remote size, used for the aggregate progress bar.
        self.content_length = content_length
        if 'sha256=' in url:
            self.sha256 = self._extract_url_param(url, 'sha256=')
        elif 'md5=' in url:
            self.md5 = self._extract_url_param(url, 'md5=')

    @staticmethod
    def _extract_url_param(url, key):
        """Return the text that follows ``key`` in ``url``, stopping at the next '&' or '#'.

        Without the cut, a URL such as ``...?sha256=abc&x=1`` would yield the
        digest ``abc&x=1``, which can never match a computed hash.
        """
        value = url.split(key, 1)[1]
        for separator in ('&', '#'):
            value = value.split(separator, 1)[0]
        return value
class CalcHashResult:
    """Outcome of hashing one local file and comparing it with its expected digest.

    Attributes:
        file_info: the file that was hashed.
        file_hash: the computed digest ("" when none was computed).
        is_hash_equals: whether ``file_hash`` matched the expected digest.
        hash_type: ``MD5_TYPE`` or ``SHA256_TYPE``.
    """

    # Possible values of ``hash_type``.
    MD5_TYPE = 0
    SHA256_TYPE = 1

    def __init__(self, file_info: DownloadFileInfo, file_hash="", is_hash_equals=False, hash_type=SHA256_TYPE):
        self.file_info, self.file_hash = file_info, file_hash
        self.is_hash_equals, self.hash_type = is_hash_equals, hash_type
def get_calc_sha256_result(file_info: DownloadFileInfo) -> CalcHashResult:
    """Hash ``file_info.dst_file_path`` with SHA-256 and compare it to ``file_info.sha256``.

    Used as a ``multiprocessing.Pool`` task. A falsy digest from the hashing helper
    is stored as "".
    """
    actual_digest = SafeFileHandler.calculate_file_sha256(file_info.dst_file_path)
    stored_digest = actual_digest or ""
    matches = actual_digest == file_info.sha256
    return CalcHashResult(file_info, stored_digest, matches, hash_type=CalcHashResult.SHA256_TYPE)
def get_calc_md5_result(file_info: DownloadFileInfo) -> CalcHashResult:
    """Hash ``file_info.dst_file_path`` with MD5 and compare it to ``file_info.md5``.

    Used as a ``multiprocessing.Pool`` task. A falsy digest from the hashing helper
    is stored as "".
    """
    actual_digest = SafeFileHandler.calculate_file_md5(file_info.dst_file_path)
    stored_digest = actual_digest or ""
    matches = actual_digest == file_info.md5
    return CalcHashResult(file_info, stored_digest, matches, hash_type=CalcHashResult.MD5_TYPE)
def get_no_hash_result(file_info: DownloadFileInfo) -> CalcHashResult:
    """Build the result for a file with no expected digest.

    The empty ``file_hash`` and ``is_hash_equals=False`` mean the file is never
    treated as already verified, so it is always selected for download.
    """
    return CalcHashResult(
        file_info=file_info,
        file_hash="",
        is_hash_equals=False,
        hash_type=CalcHashResult.SHA256_TYPE,
    )
class ParallelDownloader:
    """Download a set of files concurrently while drawing live progress on stdout.

    ``start_download`` runs the whole pipeline:
      1. Hash the files already on disk and keep only those that are missing or mismatching.
      2. Probe every remote domain.
      3. Look up remote sizes.
      4. Download in a thread pool.
      5. Re-hash the downloaded files and raise ``DownloadCheckError`` on any mismatch.

    Instances are single-use: once a download run has finished, ``download``
    returns False without doing anything.
    """

    # Download worker count: 16 by default, or ASCEND_DEPLOYER_DOWNLOAD_MAX_SIZE
    # when that variable is all digits, clamped to the range 1..256.
    __thread_num = 16
    if os.environ.get("ASCEND_DEPLOYER_DOWNLOAD_MAX_SIZE", "").isdigit():
        env_value = int(os.environ.get("ASCEND_DEPLOYER_DOWNLOAD_MAX_SIZE"))
        __thread_num = max(1, min(env_value, 256))
    _MAX_DOWNLOAD_THREAD_NUM = __thread_num
    # Hashing runs in worker processes: one per CPU, at most 32.
    _MAX_CALC_HASH_NUM = min(multiprocessing.cpu_count(), 32)

    def __init__(self, file_info_list: List[DownloadFileInfo], parent_instance=None):
        """
        Args:
            file_info_list: files to fetch. Entries that repeat an earlier
                (url, dst_file_path) pair are dropped.
            parent_instance: optional owner whose ``open_failed_url`` list collects
                unreachable URLs. Without one, such failures are only logged.
        """
        self._file_info_list = self._deduplicate_download_files(file_info_list)
        # Guards _pkg_download_status_map: download threads write it, the printer thread reads it.
        self._lock = threading.Lock()
        # NOTE(review): keyed by filename, so two files sharing a name share one progress row.
        self._pkg_download_status_map: Dict[str, DownloadingStatus] = {}
        self._need_download_files = []
        self._is_download_finished = False
        # Per-package lines drawn by the previous redraw; -1 means nothing has been drawn yet.
        self._last_print_lines_num = -1
        self.parent = parent_instance

    @staticmethod
    def _deduplicate_download_files(download_files: List[DownloadFileInfo]):
        """Return ``download_files`` without repeated (url, dst_file_path) pairs, keeping first occurrences."""
        tmp_set = set()
        res = []
        for file_info in download_files:
            dst_info = (file_info.url, file_info.dst_file_path)
            if dst_info in tmp_set:
                continue
            tmp_set.add(dst_info)
            res.append(file_info)
        return res

    def _update_pkg_download_status_map(self, downloading_status: DownloadingStatus):
        """Store the latest status of a package (thread-safe)."""
        with self._lock:
            self._pkg_download_status_map[downloading_status.pkg_name] = downloading_status

    def _call_schedule(self, downloading_status_builder: DownloadingStatusBuilder):
        """Return a ``urlretrieve`` reporthook that records progress through ``downloading_status_builder``."""
        def schedule(block_num, block_size, total_size):
            block_info = DownloadingBlockInfo(block_num, block_size, total_size)
            status = downloading_status_builder.build_downloading_status(block_info)
            self._update_pkg_download_status_map(status)

        return schedule

    def download(self, file_info: DownloadFileInfo):
        """Download one file to ``file_info.dst_file_path``, replacing any existing copy.

        Returns:
            True on success. False if this downloader has already finished its run.
        Raises:
            DownloadError: every retry failed (raised by ``_download_with_retry``).
        """
        if self._is_download_finished:
            return False
        parent_dir = os.path.dirname(file_info.dst_file_path)
        # A bare file name has no parent directory to create (os.makedirs('') raises).
        if parent_dir and not os.path.exists(parent_dir):
            LOG.info("mkdir : %s", os.path.basename(parent_dir))
            os.makedirs(parent_dir, mode=0o750, exist_ok=True)
        delete_if_exist(file_info.dst_file_path)
        if self._download_with_retry(file_info):
            LOG.info('download %s successfully', file_info.url)
            return True
        return False

    def _download_with_retry(self, file_info: DownloadFileInfo, retry_times=5):
        """Fetch ``file_info.url`` to ``file_info.dst_file_path``, making up to ``retry_times`` attempts.

        An EulerOS ``libtool-ltdl`` package is redirected into a ``docker``
        sub-directory of its destination. ``file_info.dst_file_path`` is updated in
        place, so the later hash check reads the relocated file.

        Returns:
            True once the downloaded file exists on disk.
        Raises:
            DownloadError: no attempt succeeded.
        """
        # Process-wide default: affects every socket created afterwards, in every thread.
        socket.setdefaulttimeout(DownloadUtil.INIT_TIMEOUT)
        need_change_item = ["libtool-ltdl"]
        file_name = os.path.basename(file_info.dst_file_path)
        file_name_dir = os.path.basename(os.path.dirname(file_info.dst_file_path))
        for item in need_change_item:
            if "EulerOS" in file_info.dst_file_path and "docker" != file_name_dir and item in file_name:
                file_info.dst_file_path = '/docker/'.join(file_info.dst_file_path.rsplit('/', 1))
                # download() only created the original parent directory, not this sub-directory.
                relocated_dir = os.path.dirname(file_info.dst_file_path)
                if relocated_dir:
                    os.makedirs(relocated_dir, mode=0o750, exist_ok=True)
                break
        error_msg = ""
        for retry in range(1, retry_times + 1):
            try:
                LOG.info('downloading try: %s from %s', retry, file_info.url)
                DownloadUtil.proxy_inst.build_proxy_handler()
                status_builder = DownloadingStatusBuilder(file_info.filename)
                local_file, _ = request.urlretrieve(file_info.url, file_info.dst_file_path,
                                                    self._call_schedule(status_builder))
                if is_exists(local_file):
                    return True
            except Exception as ex:
                error_msg = DownloadUtil.handle_download_error(retry, retry_times, ex, file_info.url)
                # Back off between attempts, but do not stall after the final one.
                if retry < retry_times:
                    time.sleep(retry * 2)
        LOG.error(f"Download {file_info.url} failed")
        raise DownloadError(file_info.url, file_info.dst_file_path, error_msg)

    @staticmethod
    def _cursor_up():
        """Move the terminal cursor up one line (ANSI ``ESC[1A``)."""
        sys.stdout.write('\x1b[1A')

    @staticmethod
    def _cursor_to_line_begin():
        """Erase from the cursor to the end of the line (ANSI ``ESC[K``); the cursor does not move."""
        sys.stdout.write('\033[K')

    def _sys_out_download_progress(self):
        """Redraw the progress area: one row per in-flight package plus the aggregate row."""
        with self._lock:
            # Move back over, and erase, the rows drawn last time.
            self._clear_last_download_cursor()
            print_items = [status for status in self._pkg_download_status_map.values()
                           if status.is_downloading()]
            for item in print_items:
                sys.stdout.write(item.get_downloading_print_str() + "\n")
            sys.stdout.write(self._get_all_download_progress_str() + "\n")
            cur_print_lines_num = len(print_items)
            if self._last_print_lines_num > cur_print_lines_num:
                for _ in range(self._last_print_lines_num - cur_print_lines_num + 1):
                    self._cursor_to_line_begin()
            self._last_print_lines_num = cur_print_lines_num
            sys.stdout.flush()

    def _clear_last_download_cursor(self):
        """Move the cursor up over the previously drawn rows (packages + aggregate), erasing each."""
        for _ in range(self._last_print_lines_num + 1):
            self._cursor_up()
            self._cursor_to_line_begin()
        sys.stdout.flush()

    def _get_all_download_progress_str(self):
        """Return the aggregate progress row, or "" when no remote sizes are known.

        The caller is expected to hold ``self._lock``.
        """
        total_download_size = sum(obj.content_length for obj in self._need_download_files if
                                  hasattr(obj, 'content_length') and isinstance(obj.content_length, (int, float)))
        downloaded_files_size = sum(
            status.total_size * status.percent for status in self._pkg_download_status_map.values())
        if total_download_size == 0:
            return ""
        percent = downloaded_files_size / total_download_size
        if percent > 1:
            percent = 1
        percent_str = "{:.2f}%".format(percent * 100)
        n = round(percent * 30)
        s = ('=' * (n - 1) + '>').ljust(30, '-')
        if percent == 1:
            s = ('=' * n).ljust(30, '-')
        return '\r' + Color.CLEAR + Color.info('All Download Progress:').ljust(81, ' ') + percent_str.ljust(7, ' ') \
            + '[' + s + ']'

    def _await_download(self):
        """Printer-thread body: redraw every second until downloads finish, then erase the area."""
        while not self._is_download_finished:
            self._sys_out_download_progress()
            time.sleep(1)
        self._sys_out_download_progress()
        self._clear_last_download_cursor()

    def _parallel_calc_all_file_hash(self, file_info_list: List[DownloadFileInfo]) -> List[CalcHashResult]:
        """Hash every file in a process pool and return the results in input order.

        sha256 is used when configured, else md5, else a "no hash" result. Each
        result's ``file_info`` is a pickled copy of the input object, not the
        object itself.
        """
        if not file_info_list:
            # Nothing to hash: avoid spawning a process pool.
            return []
        with multiprocessing.Pool(processes=self._MAX_CALC_HASH_NUM) as process_pool:
            pending = []
            for file_info in file_info_list:
                if file_info.sha256:
                    task = get_calc_sha256_result
                elif file_info.md5:
                    task = get_calc_md5_result
                else:
                    task = get_no_hash_result
                pending.append(process_pool.apply_async(task, (file_info,)))
            # Collect inside the ``with``: Pool.__exit__ terminates the workers.
            return [async_result.get() for async_result in pending]

    def parallel_download(self, need_download_files):
        """Download ``need_download_files`` concurrently while a background thread draws progress.

        Every download runs to completion before results are collected. The
        progress thread is always stopped, even if a download raised.

        Returns:
            list: the ``download`` result of each file, in input order.
        Raises:
            Exception: the exception of the first failed download in input order
                (e.g. ``DownloadError``).
        """
        print_thread = threading.Thread(target=self._await_download)
        print_thread.start()
        try:
            with ThreadPoolExecutor(max_workers=self._MAX_DOWNLOAD_THREAD_NUM) as thread_pool:
                futures = [thread_pool.submit(self.download, file_info) for file_info in need_download_files]
        finally:
            # Leaving the ``with`` waited for every task; stop the printer before anything propagates.
            self._is_download_finished = True
            print_thread.join()
        # Return a list rather than a lazy iterator: callers test it with all() and then return it.
        return [future.result() for future in futures]

    def _record_open_failure(self, domain, url, reason):
        """Report a URL that could not be opened.

        The failure is appended to ``self.parent.open_failed_url`` when a parent was
        given. ``parent_instance`` defaults to None; in that case it is only logged.
        """
        if self.parent is None:
            LOG.error("open url %s failed: %s", url, reason)
            return
        self.parent.open_failed_url.append(UrlOpenErrInfo(domain, url, reason).__dict__)

    def validate_domains(self, file_infos):
        """Group ``file_infos`` by URL host and return those whose host answers a probe.

        Entries with an empty URL, or a URL without a network location, are
        skipped. URLs that cannot be parsed are reported as open failures.
        """
        domain_groups = defaultdict(list)
        for info in file_infos:
            if not info.url:
                continue
            try:
                parsed = urlparse(info.url)
            except Exception as e:
                self._record_open_failure("", info.url, str(e))
                continue
            if not parsed.netloc:
                continue
            domain_groups[parsed.netloc].append(info)
        return self.get_download_files(domain_groups)

    def get_download_files(self, domain_groups):
        """Probe each domain and return the files of every domain that answered.

        Up to two URLs per domain are sampled at random. A domain is kept when at
        least one probe returns a status code in [200, 400). Otherwise each probe
        outcome is reported as an open failure and the domain's files are dropped.
        """
        need_download_files = []
        for domain, infos in domain_groups.items():
            success_count = 0
            test_results = []
            for info in random.sample(infos, min(2, len(infos))):
                try:
                    status_code = DOWNLOAD_INST.urlopen(info.url, read_response=False)
                    if 200 <= status_code < 400:
                        success_count += 1
                        test_results.append((info.url, "success"))
                    else:
                        test_results.append((info.url, f"error code: {str(status_code)}"))
                except UrlOpenError as e:
                    test_results.append((info.url, f"url open failed with err: {str(e.err_msg)}"))
                except Exception as e:
                    test_results.append((info.url, f"error: {str(e)}"))
            if success_count > 0:
                need_download_files.extend(infos)
            else:
                for url, status in test_results:
                    self._record_open_failure(domain, url, status)
        return need_download_files

    @staticmethod
    def get_file_size(need_download_file):
        """Fill ``need_download_file.content_length`` with the remote size."""
        need_download_file.content_length = get_remote_content_length(need_download_file.url)

    def start_download(self):
        """Run the full pipeline and return the per-file download results.

        Returns:
            list: ``download`` results in the order of the files that needed downloading.
        Raises:
            DownloadError: a file could not be downloaded after all retries.
            DownloadCheckError: a downloaded file's sha256/md5 differs from the configured one.
        """
        calc_all_file_hash_results = self._parallel_calc_all_file_hash(self._file_info_list)
        self._need_download_files = self.validate_domains(
            [calc_result.file_info for calc_result in calc_all_file_hash_results
             if not calc_result.file_hash or not calc_result.is_hash_equals])
        # The map iterator is never consumed, so exceptions raised by get_file_size are
        # discarded. Sizes only feed the aggregate progress row.
        with ThreadPoolExecutor(max_workers=self._MAX_DOWNLOAD_THREAD_NUM) as thread_pool:
            thread_pool.map(self.get_file_size, self._need_download_files)
        results = self.parallel_download(self._need_download_files)
        if not all(results):
            return results
        calc_download_file_hash_res = self._parallel_calc_all_file_hash(self._need_download_files)
        # Check every file that has an expected digest, md5 as well as sha256.
        calc_failed_res = [result for result in calc_download_file_hash_res
                           if (result.file_info.sha256 or result.file_info.md5) and not result.is_hash_equals]
        if calc_failed_res:
            for calc_result in calc_failed_res:
                LOG.error(f'The downloaded file:{calc_result.file_info.dst_file_path} ,url: {calc_result.file_info.url}'
                          f' file hash is not equal to the hash in config file.')
            raise DownloadCheckError("\n".join(item.file_info.dst_file_path for item in calc_failed_res))
        return results