"""
Sandbox API Service
代码执行沙箱服务 - 支持多种语言执行
端口: 8003
认证: Authorization: Bearer <key>
端点:
- POST /python/execute - Python代码执行
- POST /shell/execute - Shell命令执行
- POST /data/execute - 数据处理执行
"""
import os
import ast
import time
import subprocess
import platform
import traceback
from flask import Flask, request, jsonify
from flask_cors import CORS
from contextlib import redirect_stderr
import io
import logging
import sys
import pandas as pd
from etl.registry import get_reader
from utils.logger.eslog import CMRESHandler
from utils.common_utils import df_to_list
IS_WINDOWS = platform.system() == 'Windows'
if not IS_WINDOWS:
import resource
else:
resource = None
app = Flask(__name__)
CORS(app, origins=["*"])
SANDBOX_BEARER_KEY = os.environ.get('SANDBOX_BEARER_KEY', 'default-bearer-key')
DEFAULT_ALLOWED_MODULES = [
'pandas', 'numpy', 'math', 'datetime', 'json', 'pyecharts',
're', 'collections', 'itertools', 'operator', 'functools',
'base64', 'hashlib', 'hmac', 'uuid'
]
allowed_modules_str = os.environ.get('SANDBOX_ALLOWED_MODULES', '')
ALLOWED_MODULES = (
[m.strip() for m in allowed_modules_str.split(',') if m.strip()]
if allowed_modules_str else DEFAULT_ALLOWED_MODULES
)
DEFAULT_DANGEROUS_PATTERNS = [
'eval', 'exec', 'compile', 'open', 'file', 'input',
'globals', 'locals', 'dir', 'vars',
'delattr', 'reload', 'execfile',
'os.', 'sys.', 'subprocess.', 'socket.', 'urllib.', 'requests.',
'import os', 'import sys', 'import subprocess', 'import socket',
'multiprocessing', 'threading', 'ctypes', '__builtin__', 'builtins'
]
dangerous_patterns_str = os.environ.get('SANDBOX_DANGEROUS_PATTERNS', '')
DANGEROUS_PATTERNS = (
[p.strip() for p in dangerous_patterns_str.split(',') if p.strip()]
if dangerous_patterns_str else DEFAULT_DANGEROUS_PATTERNS
)
DEFAULT_DANGEROUS_COMMANDS = ['rm -rf', 'rm -fr', 'dd if=', 'mkfs', 'format', '>>', '&&', '||', '; ']
dangerous_commands_str = os.environ.get('SANDBOX_DANGEROUS_COMMANDS', '')
DANGEROUS_COMMANDS = (
[c.strip() for c in dangerous_commands_str.split(',') if c.strip()]
if dangerous_commands_str else DEFAULT_DANGEROUS_COMMANDS
)
def verify_bearer_token():
"""验证Bearer Token"""
auth_header = request.headers.get('Authorization', '')
if not auth_header.startswith('Bearer '):
return False, 'Missing or invalid Authorization header'
token = auth_header.split(' ')[1]
if token != SANDBOX_BEARER_KEY:
return False, 'Invalid bearer token'
return True, ''
def init_logger(logger_config: dict):
"""
初始化logger,根据配置返回 ES logger 或普通 logger
logger_config 格式:
{
'type': 'elasticsearch',
'hosts': 'http://elastic:pass@host:9200',
'index': 'task_logs',
'p_name': 'sandbox_task', # logger名称
'additional_fields': { # 额外的ES字段,如task_uuid等
'task_uuid': 'xxx',
'worker': 'xxx'
}
}
"""
p_name = logger_config.get('p_name', 'sandbox')
additional_fields = logger_config.get('additional_fields', {})
if not logger_config or logger_config.get('type') != 'elasticsearch':
logger = logging.getLogger(p_name)
logger.setLevel(logging.INFO)
logger.handlers.clear()
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
try:
hosts = logger_config.get('hosts', '')
index = logger_config.get('index', 'sandbox_logs')
if isinstance(hosts, str):
if '@' in hosts:
auth_part = hosts.split('@')[0].replace('http://', '').replace('https://', '')
if ':' in auth_part:
username, password = auth_part.split(':')
auth_details = (username, password)
auth_type = CMRESHandler.AuthType.BASIC_AUTH
host_part = hosts.split('@')[1]
hosts = [host_part]
else:
auth_type = CMRESHandler.AuthType.NO_AUTH
auth_details = ('', '')
hosts = [hosts]
else:
auth_type = CMRESHandler.AuthType.NO_AUTH
auth_details = ('', '')
hosts = [hosts]
elif isinstance(hosts, list):
auth_type = CMRESHandler.AuthType.NO_AUTH
auth_details = ('', '')
else:
auth_type = CMRESHandler.AuthType.NO_AUTH
auth_details = ('', '')
hosts = []
logger = logging.getLogger(p_name)
logger.setLevel(logging.INFO)
logger.handlers.clear()
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
es_fields = {
'source': 'sandbox',
**additional_fields
}
handler = CMRESHandler(
hosts=hosts,
auth_type=auth_type,
auth_details=auth_details,
es_index_name=index,
buffer_size=0,
es_additional_fields=es_fields
)
handler.formatter = formatter
logger.addHandler(handler)
return logger
except Exception as e:
print(f"Failed to initialize ES logger: {e}, using simple logger instead")
logger = logging.getLogger(p_name)
logger.setLevel(logging.INFO)
logger.handlers.clear()
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger
def validate_python_code(code: str) -> tuple[bool, str]:
"""验证Python代码安全性"""
try:
tree = ast.parse(code)
for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
if alias.name not in ALLOWED_MODULES:
return False, f"Import '{alias.name}' not allowed"
if isinstance(node, ast.ImportFrom):
if node.module:
top_module = node.module.split('.')[0]
if top_module not in ALLOWED_MODULES:
return False, f"Import from '{node.module}' not allowed (top module '{top_module}' not in allowed list)"
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
if node.func.id in ['eval', 'exec', 'compile']:
return False, f"Function '{node.func.id}' is not allowed"
except SyntaxError as e:
return False, f"Syntax error: {str(e)}"
for pattern in DANGEROUS_PATTERNS:
if pattern in code:
return False, f"Dangerous pattern detected: '{pattern}'"
return True, "Safe"
def execute_python(code: str, context: dict, timeout: int = 30, logger_config: dict = None) -> dict:
"""执行Python代码"""
try:
logger = init_logger(logger_config or {})
if not IS_WINDOWS and resource:
resource.setrlimit(resource.RLIMIT_CPU, (timeout, timeout))
resource.setrlimit(resource.RLIMIT_AS, (512 * 1024 * 1024, 512 * 1024 * 1024))
stdout_capture = io.StringIO()
stderr_capture = io.StringIO()
exec_globals = {
'logger': logger,
'__builtins__': __builtins__,
**context
}
exec_locals = {}
original_stdout = sys.stdout
sys.stdout = stdout_capture
try:
with redirect_stderr(stderr_capture):
exec(code, exec_globals, exec_locals)
finally:
sys.stdout = original_stdout
result = exec_locals.get('result', None)
if isinstance(result, dict) and 'value' in result and isinstance(result['value'], pd.DataFrame):
result['value'] = df_to_list(result['value'])
return {
'success': True,
'result': result,
'output': stdout_capture.getvalue(),
'error': stderr_capture.getvalue()
}
except MemoryError:
return {'success': False, 'error': 'Memory limit exceeded'}
except Exception as e:
return {'success': False, 'error': f'{type(e).__name__}: {str(e)}\n{traceback.format_exc()}'}
def execute_shell(command: str, timeout: int = 30, logger_config: dict = None) -> dict:
"""执行Shell命令,参考 shell_tasks.py 的实现"""
try:
logger = init_logger(logger_config or {})
for cmd in DANGEROUS_COMMANDS:
if cmd in command:
logger.error(f'Dangerous shell command detected: {cmd}')
return {
'success': False,
'error': f'Dangerous shell command detected: {cmd}'
}
logger.info(f'Executing command: {command}')
process = subprocess.Popen(
command,
shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
text=True
)
start_time = time.time()
output_lines = []
while True:
line = process.stdout.readline()
if line == '' and process.poll() is not None:
break
if line:
clean_line = line.rstrip()
if clean_line:
logger.info(clean_line)
output_lines.append(line)
if time.time() - start_time > timeout:
logger.warning(f'Command timeout after {timeout} seconds, terminating...')
process.terminate()
process.wait(timeout=5)
return {
'success': False,
'error': f'Command timeout after {timeout} seconds'
}
exit_code = process.wait()
if exit_code == 0:
logger.info(f'Command completed successfully')
else:
logger.error(f'Command exited with code {exit_code}')
return {
'success': exit_code == 0,
'result': {
'exit_code': exit_code,
'output': ''.join(output_lines)
},
'error': None if exit_code == 0 else f'Command exited with code {exit_code}'
}
except Exception as e:
return {'success': False, 'error': f'{type(e).__name__}: {str(e)}\n{traceback.format_exc()}'}
def execute_data(code: str, config: dict, timeout: int = 600) -> dict:
"""执行数据处理代码,根据数据模型配置实例化reader"""
try:
extract_info = config.get('extract_info', {})
reader = None
if extract_info:
flag, reader = get_reader(extract_info)
if not flag:
raise RuntimeError(f'reader初始化失败:{reader}')
else:
raise RuntimeError(f'reade信息获取失败')
context = {
'reader': reader
}
return execute_python(code, context, timeout)
except Exception as e:
return {'success': False, 'error': f'Data execution error: {str(e)}\n{traceback.format_exc()}'}
@app.route('/python/execute', methods=['POST'])
def python_execute():
"""Python代码执行端点"""
valid, msg = verify_bearer_token()
if not valid:
return jsonify({'success': False, 'error': msg}), 401
data = request.get_json()
if not data:
return jsonify({'success': False, 'error': 'No data provided'}), 400
code = data.get('code', '')
context = data.get('context', {})
timeout = data.get('timeout', 30)
logger_config = data.get('logger_config', {})
if not code:
return jsonify({'success': False, 'error': 'No code provided'}), 400
is_safe, reason = validate_python_code(code)
if not is_safe:
return jsonify({'success': False, 'error': f'Code validation failed: {reason}'}), 400
start_time = time.time()
try:
result = execute_python(code, context, timeout, logger_config)
result['execution_time'] = time.time() - start_time
return jsonify(result)
except Exception as e:
return jsonify({
'success': False,
'error': f'Execution error: {str(e)}\n{traceback.format_exc()}',
'execution_time': time.time() - start_time
}), 500
@app.route('/shell/execute', methods=['POST'])
def shell_execute():
"""Shell命令执行端点"""
valid, msg = verify_bearer_token()
if not valid:
return jsonify({'success': False, 'error': msg}), 401
data = request.get_json()
if not data:
return jsonify({'success': False, 'error': 'No data provided'}), 400
command = data.get('command', '')
timeout = data.get('timeout', 30)
logger_config = data.get('logger_config', {})
if not command:
return jsonify({'success': False, 'error': 'No command provided'}), 400
start_time = time.time()
try:
result = execute_shell(command, timeout, logger_config)
result['execution_time'] = time.time() - start_time
return jsonify(result)
except Exception as e:
return jsonify({
'success': False,
'error': f'Execution error: {str(e)}\n{traceback.format_exc()}',
'execution_time': time.time() - start_time
}), 500
@app.route('/data/execute', methods=['POST'])
def data_execute():
"""数据处理执行端点"""
valid, msg = verify_bearer_token()
if not valid:
return jsonify({'success': False, 'error': msg}), 401
data = request.get_json()
if not data:
return jsonify({'success': False, 'error': 'No data provided'}), 400
code = data.get('code', '')
config = data.get('config', {})
timeout = data.get('timeout', 60)
if not code:
return jsonify({'success': False, 'error': 'No code provided'}), 400
is_safe, reason = validate_python_code(code)
if not is_safe:
return jsonify({'success': False, 'error': f'Code validation failed: {reason}'}), 400
start_time = time.time()
try:
result = execute_data(code, config, timeout)
result['execution_time'] = time.time() - start_time
return jsonify(result)
except Exception as e:
return jsonify({
'success': False,
'error': f'Execution error: {str(e)}\n{traceback.format_exc()}',
'execution_time': time.time() - start_time
}), 500
@app.route('/health', methods=['GET'])
def health():
"""健康检查端点"""
return jsonify({
'status': 'healthy',
'timestamp': time.time(),
'version': '1.0.0',
'endpoints': ['/python/execute', '/shell/execute', '/data/execute']
})
if __name__ == '__main__':
port = int(os.environ.get('SANDBOX_PORT', 8003))
debug = os.environ.get('SANDBOX_DEBUG', 'false').lower() == 'true'
print(f"""
╔══════════════════════════════════════╗
║ Sandbox API Service ║
╠══════════════════════════════════════╣
║ Port: {port:<34} ║
║ Debug: {debug:<34} ║
║ Auth: Bearer Token ║
║ Endpoints: ║
║ - POST /python/execute ║
║ - POST /shell/execute ║
║ - POST /data/execute ║
╚══════════════════════════════════════╝
""")
app.run(host='0.0.0.0', port=port, debug=debug)