"""CMS shared utilities."""
import os
import sys
import re
import subprocess
import signal
import time
import glob as glob_mod
from log_config import get_logger
# Module-wide logger obtained from the project's log_config helper.
LOGGER = get_logger()
# Parent of the directory that contains this file.
_ACTION_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# Make that directory importable; presumably this is where nofile_utils
# (imported right below) lives -- TODO confirm.
if _ACTION_ROOT not in sys.path:
    sys.path.append(_ACTION_ROOT)
from nofile_utils import (
apply_nofile_rlimit_before_setuid,
resolve_nofile_rlimit_for_user,
)
class CommandError(Exception):
    """Error raised when a shell command finishes with a failing status.

    Attributes:
        cmd: The command text that was executed.
        returncode: Exit status reported for the command.
        stdout: Captured standard output (empty string by default).
        stderr: Captured standard error (empty string by default).
    """

    def __init__(self, cmd, returncode, stdout="", stderr=""):
        # Keep the raw pieces available so callers can inspect them.
        self.cmd, self.returncode = cmd, returncode
        self.stdout, self.stderr = stdout, stderr
        headline = f"Command failed (rc={returncode}): {cmd}\n"
        streams = f"stdout: {stdout}\nstderr: {stderr}"
        super().__init__(headline + streams)
def exec_popen(cmd, timeout=1800):
    """Execute a shell command through ``bash`` and capture its output.

    The command text is fed to a ``bash`` process on stdin (followed by a
    line separator) instead of being passed on the argument vector.

    Args:
        cmd: Shell command text to run.
        timeout: Maximum number of seconds to wait for completion.

    Returns:
        Tuple ``(returncode, stdout, stderr)`` with trailing newlines removed
        from both streams. On timeout the tuple is
        ``(-1, "Time Out.", "Command timed out after <timeout>s")``.
    """
    pobj = subprocess.Popen(
        ["bash"],
        shell=False,
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    try:
        stdout_bytes, stderr_bytes = pobj.communicate(
            input=(cmd + os.linesep).encode(),
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        # Kill bash and reap it so no zombie process is left behind.
        pobj.kill()
        pobj.communicate()
        return -1, "Time Out.", f"Command timed out after {timeout}s"
    # Commands may print bytes that are not valid UTF-8 (binary data, other
    # locales); substitute them instead of raising UnicodeDecodeError.
    stdout = stdout_bytes.decode(errors="replace").rstrip(os.linesep)
    stderr = stderr_bytes.decode(errors="replace").rstrip(os.linesep)
    return pobj.returncode, stdout, stderr
def run_cmd(cmd, error_msg="Command failed", force_uninstall=None):
    """Run *cmd* via exec_popen and return its stdout.

    A non-zero return code is logged together with the combined output.
    Unless ``force_uninstall`` equals ``"force"``, the failure is then raised
    as CommandError; in force mode the captured stdout is returned anyway.
    """
    rc, out, err = exec_popen(cmd)
    if not rc:
        return out
    combined = out + err
    LOGGER.error("%s.\ncommand: %s.\noutput: %s" % (error_msg, cmd, combined))
    if force_uninstall == "force":
        return out
    raise CommandError(cmd, rc, out, err)
def run_as_user(cmd, user, log_file=None):
    """Run a shell command as *user* through ``su`` and return exec_popen's result.

    When *log_file* is given, the command's stdout and stderr are appended to
    it from inside the ``su`` shell. Returns (returncode, stdout, stderr).

    NOTE(review): *cmd* is placed verbatim between double quotes, so any
    double quote inside it has to be escaped by the caller.
    """
    payload = f"{cmd} >> {log_file} 2>&1" if log_file else cmd
    return exec_popen(f'su -s /bin/bash - {user} -c "{payload}"')
def _tail_log(log_file, n=20):
    """Return the last *n* lines of *log_file* as an error summary.

    The summary ends with a pointer to the full log. An empty string is
    returned when the file cannot be read or the tail is blank.
    """
    try:
        with open(log_file, "r", encoding="utf-8", errors="replace") as log_fh:
            recent = log_fh.readlines()[-n:]
    except OSError:
        return ""
    summary = "".join(recent).strip()
    return f"{summary}\nSee full log: {log_file}" if summary else ""
def run_python_as_user(script, args, user, log_file=None, cwd=None, timeout=1800):
    """Execute Python script as specified user via subprocess with uid/gid switch.

    Args:
        script: Path of the Python script, run with the current interpreter
            (``sys.executable``).
        args: Iterable of command-line arguments appended after the script.
        user: Account name; looked up with ``pwd.getpwnam``.
        log_file: Optional path; captured stdout/stderr are appended to it
            by this (parent) process after the child exits.
        cwd: Working directory for the child; defaults to the script's
            directory.
        timeout: Seconds to wait for the child before killing it.

    Returns:
        Tuple ``(returncode, stdout, stderr)``. On timeout the tuple is
        ``(-1, "", "Timeout after <timeout>s")``.
    """
    import pwd
    pw = pwd.getpwnam(user)
    uid, gid, home = pw.pw_uid, pw.pw_gid, pw.pw_dir
    # Limits are resolved in the parent and captured by the closure below.
    # NOTE(review): presumably the open-file (nofile) soft/hard limits for
    # this user -- confirm against nofile_utils.
    soft, hard = resolve_nofile_rlimit_for_user(user)
    def _demote():
        """preexec_fn: switch to target user before child exec."""
        apply_nofile_rlimit_before_setuid(soft, hard)
        # Group identity is changed first; setuid comes last.
        os.setgid(gid)
        os.initgroups(user, gid)
        os.setuid(uid)
    # Child gets the parent's environment with the identity variables
    # pointed at the target user.
    env = os.environ.copy()
    env.update({"HOME": home, "USER": user, "LOGNAME": user})
    cmd_list = [sys.executable, script] + list(args)
    work_dir = cwd or os.path.dirname(os.path.abspath(script))
    try:
        proc = subprocess.Popen(
            cmd_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
            cwd=work_dir, env=env, preexec_fn=_demote,
        )
        stdout_b, stderr_b = proc.communicate(timeout=timeout)
        # Undecodable bytes are replaced rather than raising.
        stdout = stdout_b.decode("utf-8", errors="replace").strip()
        stderr = stderr_b.decode("utf-8", errors="replace").strip()
        if log_file:
            # Best-effort logging: a failure to write the log is ignored.
            try:
                log_dir = os.path.dirname(log_file)
                if log_dir:
                    os.makedirs(log_dir, exist_ok=True)
                with open(log_file, "a", encoding="utf-8") as log_fh:
                    combined = "\n".join(part for part in (stdout, stderr) if part)
                    if combined:
                        log_fh.write(combined + "\n")
            except OSError:
                pass
        # A failing child that printed nothing on stderr: fall back to the
        # tail of the log file so the caller still gets an error summary.
        if proc.returncode != 0 and not stderr and log_file:
            fallback = _tail_log(log_file, 20)
            if fallback:
                stderr = fallback
        return proc.returncode, stdout, stderr
    except subprocess.TimeoutExpired:
        # Kill the child and reap it before reporting the timeout.
        proc.kill()
        proc.communicate()
        return -1, "", f"Timeout after {timeout}s"
def _parse_backup_list_line(line):
    """Split one backup-list entry of the form ``[size] path`` into its parts.

    All spaces are removed before parsing, and anything from ``->`` onwards
    in the path part is dropped. Returns ``(record_size, file_path)``, or
    ``(None, None)`` for blank lines, lines without ``]`` and lines with no
    path after the bracket.
    """
    entry = line.strip()
    if not entry:
        return None, None
    compact = entry.replace(" ", "").lstrip("[")
    record_size, bracket, file_path = compact.partition("]")
    if not bracket or not file_path:
        return None, None
    if "->" in file_path:
        file_path = file_path.split("->")[0]
    return record_size, file_path
def check_backup_files(backup_list_file, dest_dir, orig_dir):
    """Check the backup copy under *dest_dir* against the backup list.

    Delegates to _do_file_check in ``"backup"`` mode, which raises
    FileCheckError when a listed file is missing or has the wrong size.
    """
    announcement = f"check backup files in {dest_dir} from {orig_dir}"
    LOGGER.info(announcement)
    _do_file_check(backup_list_file, dest_dir, orig_dir, mode="backup")
def check_rollback_files(backup_list_file, dest_dir, orig_dir):
    """Check the rolled-back files under *orig_dir* against the backup list.

    Delegates to _do_file_check in ``"rollback"`` mode, which raises
    FileCheckError when a listed file is missing or has the wrong size.
    """
    announcement = f"check rollback files in {dest_dir} from {orig_dir}"
    LOGGER.info(announcement)
    _do_file_check(backup_list_file, dest_dir, orig_dir, mode="rollback")
class FileCheckError(Exception):
    """Raised when a backup or rollback file verification fails."""
def _do_file_check(backup_list_file, dest_dir, orig_dir, mode="backup"):
    """Verify every entry of the backup list for either direction.

    In ``"backup"`` mode the destination copy must exist; in any other mode
    the original path must exist. For regular files the sizes on both sides
    must match, and must also equal the size recorded in the list when one
    is present. Entries outside *orig_dir*, and entries whose path starts
    with ``<orig_dir>/log``, are skipped.

    Raises:
        FileCheckError: On the first missing file or size mismatch.

    NOTE(review): directory matching is a plain string-prefix test, so
    sibling paths sharing the prefix are matched too -- confirm intended.
    """
    log_prefix = os.path.join(orig_dir, "log")
    with open(backup_list_file, "r") as list_fh:
        for raw_line in list_fh:
            record_size, orig_path = _parse_backup_list_line(raw_line)
            if not orig_path or not orig_path.startswith(orig_dir):
                continue
            if orig_path.startswith(log_prefix):
                continue

            # Map the original path onto the destination tree.
            dest_path = dest_dir + orig_path[len(orig_dir):]
            check_path = dest_path if mode == "backup" else orig_path

            if not os.path.exists(check_path):
                msg = f"File not found: {orig_path} -> {dest_path}"
                LOGGER.error(msg)
                raise FileCheckError(msg)
            if not os.path.isfile(check_path):
                continue

            # A side that does not exist counts as size 0.
            orig_size = os.path.getsize(orig_path) if os.path.exists(orig_path) else 0
            dest_size = os.path.getsize(dest_path) if os.path.exists(dest_path) else 0
            if orig_size != dest_size:
                msg = (f"File size mismatch: {orig_path}({orig_size}) "
                       f"-> {dest_path}({dest_size})")
                LOGGER.error(msg)
                raise FileCheckError(msg)
            if record_size and str(dest_size) != str(record_size):
                msg = (f"File size differs from record: "
                       f"recorded={record_size}, actual={dest_size}")
                LOGGER.error(msg)
                raise FileCheckError(msg)
class CGroupManager:
    """Manage a memory cgroup directory used to isolate a process.

    Attributes:
        cgroup_path: Directory of the cgroup.
        mem_size_gb: Memory limit in gigabytes, written by configure().
    """

    def __init__(self, cgroup_path, mem_size_gb=10):
        self.cgroup_path = cgroup_path
        self.mem_size_gb = mem_size_gb

    def create(self):
        """Create the cgroup directory if it does not exist yet."""
        os.makedirs(self.cgroup_path, exist_ok=True)
        LOGGER.info(f"cgroup path created: {self.cgroup_path}")

    def configure(self, process_keyword="cms server -start"):
        """Write the memory limit and attach the first matching process.

        The limit goes to ``memory.limit_in_bytes``; the pid of the first
        ``ps -ef`` line matching *process_keyword* (if any) goes to
        ``tasks``.
        """
        limit_path = os.path.join(self.cgroup_path, "memory.limit_in_bytes")
        tasks_path = os.path.join(self.cgroup_path, "tasks")
        with open(limit_path, "w") as limit_fh:
            limit_fh.write(f"{self.mem_size_gb}G")
        LOGGER.info(f"cgroup memory limit set to {self.mem_size_gb}G")

        find_pid_cmd = (
            f"ps -ef | grep '{process_keyword}' | grep -v grep | awk 'NR==1 {{print $2}}'"
        )
        ret, stdout, _ = exec_popen(find_pid_cmd)
        pid = stdout.strip()
        if ret != 0 or not pid:
            return
        with open(tasks_path, "w") as tasks_fh:
            tasks_fh.write(pid)
        LOGGER.info(f"added pid {pid} to cgroup")

    def clean(self):
        """Remove the cgroup directory; a failed removal is only logged."""
        if not os.path.isdir(self.cgroup_path):
            LOGGER.info("cgroup path does not exist, skip cleaning")
            return
        try:
            os.rmdir(self.cgroup_path)
            LOGGER.info(f"cgroup removed: {self.cgroup_path}")
        except OSError as exc:
            LOGGER.warning(f"failed to remove cgroup: {exc}")

    def setup(self, process_keyword="cms server -start"):
        """Reset the cgroup directory: clean (best effort), then create.

        NOTE(review): *process_keyword* is accepted but not used here, and
        configure() is not invoked by this method -- call it separately.
        """
        try:
            self.clean()
        except Exception:
            # Cleaning is best effort; creation below still runs.
            pass
        self.create()
class IPTablesManager:
    """Add and remove iptables ACCEPT rules for the CMS port."""

    @staticmethod
    def _get_iptables_path():
        """Return what ``whereis iptables`` prints after the first colon, or ""."""
        ret, stdout, _ = exec_popen("whereis iptables")
        if ret != 0 or not stdout:
            return ""
        if ":" not in stdout:
            return ""
        return stdout.split(":")[1].strip()

    @staticmethod
    def _rule_exists(chain, port):
        """Tell whether *chain* lists a tcp ACCEPT rule mentioning *port*."""
        count_cmd = (
            f"iptables -L {chain} -w 60 | grep ACCEPT | grep {port} | grep tcp | wc -l"
        )
        ret, stdout, _ = exec_popen(count_cmd)
        if ret != 0:
            return False
        return stdout.strip() != "0"

    @classmethod
    def accept(cls, cms_config_file):
        """Insert ACCEPT rules for the CMS port into INPUT, FORWARD and OUTPUT.

        Nothing is done when the port cannot be read or iptables is not
        found. Chains that already have a matching rule are left alone.
        """
        port = cls._read_port(cms_config_file)
        if not port:
            LOGGER.warning("cannot read CMS port, skip iptables")
            return
        if not cls._get_iptables_path():
            LOGGER.info("iptables not found, skip")
            return
        LOGGER.info(f"adding iptables ACCEPT rules for port {port}")
        for chain in ("INPUT", "FORWARD", "OUTPUT"):
            if cls._rule_exists(chain, port):
                continue
            exec_popen(f"iptables -I {chain} -p tcp --sport {port} -j ACCEPT -w 60")

    @classmethod
    def delete(cls, cms_config_file):
        """Delete the CMS port ACCEPT rules from INPUT, FORWARD and OUTPUT.

        Silently returns when the port cannot be read or iptables is not
        found. Only chains that currently have a matching rule are touched.
        """
        port = cls._read_port(cms_config_file)
        if not port or not cls._get_iptables_path():
            return
        LOGGER.info(f"deleting iptables rules for port {port}")
        for chain in ("INPUT", "FORWARD", "OUTPUT"):
            if not cls._rule_exists(chain, port):
                continue
            exec_popen(f"iptables -D {chain} -p tcp --sport {port} -j ACCEPT -w 60")

    @staticmethod
    def _read_port(config_file):
        """Return the value of the first config line containing ``_PORT``.

        The value is the text after the last ``=`` on that line. An empty
        string is returned when the file is missing, unreadable, or has no
        such line.
        """
        if not os.path.exists(config_file):
            return ""
        try:
            with open(config_file, "r") as cfg:
                port_line = next((row for row in cfg if "_PORT" in row), None)
        except OSError:
            return ""
        if port_line is None:
            return ""
        return port_line.split("=")[-1].strip()
class ProcessManager:
    """Helpers for locating, killing and waiting on processes."""

    # ensure_stopped polls at most CHECK_MAX_TIMES times, sleeping
    # CHECK_INTERVAL seconds between polls.
    CHECK_MAX_TIMES = 7
    CHECK_INTERVAL = 5

    @staticmethod
    def get_pid(process_name):
        """Return the pids of the current user's processes matching *process_name*.

        The result is the stripped output of the ``ps``/``grep`` pipeline
        (one pid per line), or an empty string when nothing matches or the
        command fails.
        """
        lookup_cmd = (
            f"ps -u $(id -un) -o pid=,args= | grep '{process_name}' | "
            "grep -v grep | awk '{print $1}'"
        )
        ret, stdout, stderr = exec_popen(lookup_cmd)
        if not ret:
            return stdout.strip()
        LOGGER.error(f"Failed to get pid for '{process_name}': {stderr}")
        return ""

    @staticmethod
    def kill_process(process_name):
        """Send SIGKILL to every process of the current user matching *process_name*."""
        pid_listing = (
            f"ps -u $(id -un) -o pid=,args= | grep '{process_name}' | "
            "grep -v grep | awk '{print $1}'"
        )
        kill_cmd = (
            f"proc_pid_list=$({pid_listing}) && "
            '(if [ -n "$proc_pid_list" ]; then echo $proc_pid_list | xargs kill -9; fi)'
        )
        LOGGER.info(f"kill process: {process_name}")
        run_cmd(kill_cmd, f"failed to kill {process_name}")

    @classmethod
    def ensure_stopped(cls, process_name, force_uninstall=None):
        """Poll until *process_name* has no pid left.

        Raises RuntimeError once all polls are used up, unless
        ``force_uninstall`` equals ``"force"``.
        """
        final_round = cls.CHECK_MAX_TIMES - 1
        for attempt in range(cls.CHECK_MAX_TIMES):
            pid = cls.get_pid(process_name)
            if not pid:
                return
            LOGGER.info(f"check {attempt + 1}/{cls.CHECK_MAX_TIMES}: {process_name} pid={pid}")
            # No point sleeping after the last poll.
            if attempt < final_round:
                time.sleep(cls.CHECK_INTERVAL)
        msg = f"Failed to stop {process_name} after {cls.CHECK_MAX_TIMES * cls.CHECK_INTERVAL}s"
        LOGGER.error(msg)
        if force_uninstall != "force":
            raise RuntimeError(msg)

    @staticmethod
    def is_running(process_name):
        """Return True when get_pid finds at least one pid for *process_name*."""
        found = ProcessManager.get_pid(process_name)
        return bool(found)

    @staticmethod
    def clear_shm(shm_home="/dev/shm", shm_pattern="ograc.[0-9]*"):
        """Delete shared-memory files matching *shm_pattern* under *shm_home*.

        Cleanup is skipped when an ``ogracd`` process whose command line
        mentions ``/dev/shm/<basename of shm_home>/`` is found, or when
        *shm_home* is not a directory. Files that cannot be removed are
        ignored.
        """
        shm_dir_name = os.path.basename(os.path.normpath(shm_home))
        probe_cmd = ("ps -eo args= | grep '[o]gracd' "
                     f"| grep '/dev/shm/{shm_dir_name}/'")
        ret, stdout, _ = exec_popen(probe_cmd)
        if ret == 0 and stdout.strip():
            LOGGER.info("ogracd is running, skip shm cleanup")
            return
        if not os.path.isdir(shm_home):
            LOGGER.info(f"shm directory not found: {shm_home}, skip")
            return
        removed = 0
        for shm_file in glob_mod.glob(os.path.join(shm_home, shm_pattern)):
            try:
                os.remove(shm_file)
            except OSError:
                continue
            removed += 1
        LOGGER.info(f"shared memory cleaned: {removed} files from {shm_home}")

    @staticmethod
    def ensure_shm_dir(shm_home, user_and_group=""):
        """Make sure *shm_home* exists, then apply owner and mode 0700.

        The directory is created when missing. ``chown`` runs only when
        *user_and_group* is given; ``chmod 700`` always runs. The results of
        both commands are not checked.
        """
        already_there = os.path.isdir(shm_home)
        if not already_there:
            os.makedirs(shm_home, mode=0o700, exist_ok=True)
            LOGGER.info(f"created shm directory: {shm_home} (mode=0700)")
        if user_and_group:
            exec_popen(f"chown {user_and_group} {shm_home}")
        exec_popen(f"chmod 700 {shm_home}")
def ensure_dir(path, mode=0o750, owner=None):
    """Create *path* (with parents) if needed and optionally chown it.

    *mode* is handed to ``os.makedirs``; no separate chmod is applied. When
    *owner* is given, ownership is changed recursively via ``chown``.
    """
    os.makedirs(path, mode=mode, exist_ok=True)
    if not owner:
        return
    run_cmd(f"chown {owner} -hR {path}", f"failed to chown {path}")
def ensure_file(path, mode=0o640, owner=None):
    """Create an empty file at *path* if nothing exists there, then chmod it.

    Existing content is left untouched. When *owner* is given, ownership is
    changed via ``chown``.
    """
    missing = not os.path.exists(path)
    if missing:
        # Opening for write and closing immediately yields an empty file.
        with open(path, "w"):
            pass
    os.chmod(path, mode)
    if not owner:
        return
    run_cmd(f"chown {owner} {path}", f"failed to chown {path}")
def safe_remove(path):
    """Delete *path* whether it is a file, a symlink or a directory tree.

    Files and symlinks are unlinked; directories are removed recursively.
    A path that is none of these is left alone.
    """
    import shutil

    if os.path.isfile(path) or os.path.islink(path):
        os.remove(path)
        return
    if os.path.isdir(path):
        shutil.rmtree(path)
def copy_tree(src, dest, owner=None):
    """Copy *src* to *dest* with ``cp -arf`` and optionally chown the result.

    Both steps go through run_cmd, so a failing command raises
    CommandError.
    """
    copy_cmd = f"cp -arf {src} {dest}"
    run_cmd(copy_cmd, f"failed to copy {src} to {dest}")
    if not owner:
        return
    chown_cmd = f"chown -hR {owner} {dest}"
    run_cmd(chown_cmd, f"failed to chown {dest}")
def read_version(versions_yml_path):
    """Return the value of the first ``Version:`` line in versions.yml.

    The value is the text between the first and second colon, stripped.
    An empty string is returned when the file does not exist or has no
    such line.
    """
    if not os.path.exists(versions_yml_path):
        return ""
    with open(versions_yml_path, "r") as yml:
        for entry in map(str.strip, yml):
            if entry.startswith("Version:"):
                fields = entry.split(":")
                return fields[1].strip()
    return ""
def get_version_major(versions_yml_path):
    """Return the major version number (text before the first dot) as int.

    Returns 0 when read_version yields no version.
    """
    version = read_version(versions_yml_path)
    if not version:
        return 0
    major, _, _ = version.partition(".")
    return int(major)