"""ContextEngine 压力测试脚本
测试大规模写入场景下的写入链路有效性和一致性。
运行方式:
python3 scripts/stress_test.py --help
测试场景:
1. 单账户大规模写入 - 1000 条记忆写入
2. 多账户并发写入 - 10 账户 x 100 条
3. 相同 routing_key 并发写入 - 测试 merge 一致性
4. Outbox 处理压力测试 - 大量索引事件
"""
import argparse
import asyncio
import os
import sys
import time
import uuid
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any
sys.path.insert(0, str(Path(__file__).parent.parent))
from core.models import RequestContext, CandidateMemory
from fs.agfs_adapter import AGFSContextFS
from pyagfs import AGFSClient
from service.api import MemoryWriteAPI
from providers.llm import MockLLM
from providers.embedder import MockEmbedder
from providers.vector_index import InMemoryVectorIndex
from commit import OutboxStore
@dataclass
class StressTestResult:
    """Aggregated outcome of one stress-test scenario.

    Instances are plain value records; field order defines the positional
    constructor signature.
    """

    # Scenario identifier, e.g. "single_account_large_scale".
    name: str
    # Number of operations the scenario attempted.
    total_operations: int
    # Operations that completed without raising.
    successful: int
    # Operations that raised an exception.
    failed: int
    # Elapsed time of the scenario, in seconds.
    duration_seconds: float
    # Throughput reported by the scenario, in operations per second.
    ops_per_second: float
    # Stringified exception messages collected while running.
    errors: List[str]
class StressTestRunner:
    """Drives ContextEngine write-path stress scenarios against an AGFS backend.

    Every scenario writes through ``MemoryWriteAPI`` into the
    ``/local/stress_test`` mount, prints progress and a result block to
    stdout, and returns a ``StressTestResult``.
    """

    def __init__(self, agfs_base_url: str = "http://localhost:1833"):
        """Create the AGFS client, mock providers and the shared write API.

        Args:
            agfs_base_url: Base URL of the AGFS service to run against.
        """
        self.agfs_base_url = agfs_base_url
        self.client = AGFSClient(api_base_url=agfs_base_url)
        self.fs = AGFSContextFS(client=self.client, mount_prefix="/local/stress_test")
        self.llm = MockLLM()
        self.embedder = MockEmbedder(dimension=1536)
        self.vector_index = InMemoryVectorIndex(dimension=1536)
        self.outbox_store = OutboxStore(client=self.client, fs=self.fs, mount_prefix="/local/stress_test")
        self.api = MemoryWriteAPI(fs=self.fs, llm=self.llm, outbox_store=self.outbox_store)
        try:
            self.client.mkdir("/local/stress_test")
        except Exception:
            # Best effort: mkdir presumably fails when the directory already
            # exists (TODO confirm AGFSClient semantics). Catch Exception, not
            # a bare ``except:``, so Ctrl-C / SystemExit still propagate.
            pass

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _rate(count: int, seconds: float) -> float:
        """Return ``count / seconds`` (ops/s), or 0.0 if no time elapsed.

        Prevents ZeroDivisionError when the timer did not advance, e.g. for
        a zero-operation run.
        """
        return count / seconds if seconds > 0 else 0.0

    @staticmethod
    def _print_banner(title: str) -> None:
        """Print ``title`` framed by 60-character '=' separator lines."""
        separator = "=" * 60
        print(f"\n{separator}")
        print(title)
        print(separator)

    @staticmethod
    def _print_result(total: int, successful: int, failed: int,
                      duration: float, ops_per_sec: float) -> None:
        """Print the result block shared by tests 1-3."""
        print("\n结果:")
        print(f" 总操作数: {total}")
        print(f" 成功: {successful}")
        print(f" 失败: {failed}")
        print(f" 耗时: {duration:.2f} 秒")
        print(f" 吞吐量: {ops_per_sec:.2f} ops/s")

    def create_context(self, account_id: str, user_id: str) -> RequestContext:
        """Build a RequestContext with random session/trace ids.

        The agent id is fixed to "stress-test-agent".
        """
        return RequestContext(
            account_id=account_id,
            user_id=user_id,
            agent_id="stress-test-agent",
            session_id=str(uuid.uuid4()),
            trace_id=str(uuid.uuid4()),
        )

    # ------------------------------------------------------------------
    # Scenarios
    # ------------------------------------------------------------------

    def test_single_account_large_scale(self, num_writes: int = 1000) -> StressTestResult:
        """Test 1: many sequential writes into a single account.

        Checks that sustained writing does not crash and measures throughput.
        Routing keys cycle through ``pref_0`` .. ``pref_99``, so later writes
        reuse keys from earlier ones.

        Args:
            num_writes: Number of memories to write.

        Returns:
            The aggregated StressTestResult for this scenario.
        """
        self._print_banner(f"测试 1: 单账户大规模写入 ({num_writes} 条)")
        ctx = self.create_context("stress-test-account", "user-1")
        successful = 0
        failed = 0
        errors: List[str] = []
        start_time = time.perf_counter()
        for i in range(num_writes):
            try:
                candidate = CandidateMemory(
                    category="preference",
                    owner_scope="user",
                    routing_key=f"pref_{i % 100}",
                    abstract=f"Preference {i}",
                    overview=f"## Preference {i}\n\nTest content",
                    content=f"This is test preference number {i} for stress testing.",
                    confidence=0.9,
                )
                self.api.write_memory(candidate, ctx)
                successful += 1
            except Exception as e:
                failed += 1
                errors.append(str(e))
            # Report progress every 100 attempts, whether or not this one failed.
            if (i + 1) % 100 == 0:
                elapsed = time.perf_counter() - start_time
                print(f" 进度: {i+1}/{num_writes} ({self._rate(i + 1, elapsed):.1f} ops/s)")
        duration = time.perf_counter() - start_time
        ops_per_sec = self._rate(num_writes, duration)
        self._print_result(num_writes, successful, failed, duration, ops_per_sec)
        return StressTestResult(
            name="single_account_large_scale",
            total_operations=num_writes,
            successful=successful,
            failed=failed,
            duration_seconds=duration,
            ops_per_second=ops_per_sec,
            errors=errors,
        )

    def test_multi_account_concurrent(self, num_accounts: int = 10, writes_per_account: int = 100) -> StressTestResult:
        """Test 2: concurrent writes from several accounts, one thread each.

        Each thread writes ``writes_per_account`` entity memories under its
        own account/user. Only success/failure counts and throughput are
        measured here; cross-account isolation is not asserted.

        Args:
            num_accounts: Number of accounts, and thus worker threads.
            writes_per_account: Writes performed by each account's thread.

        Returns:
            The aggregated StressTestResult for this scenario.
        """
        self._print_banner(f"测试 2: 多账户并发写入 ({num_accounts} 账户 x {writes_per_account} 条)")
        successful = 0
        failed = 0
        errors: List[str] = []
        lock = threading.Lock()  # guards successful / failed / errors

        def write_for_account(account_idx: int) -> None:
            nonlocal successful, failed
            ctx = self.create_context(f"stress-account-{account_idx}", f"user-{account_idx}")
            for i in range(writes_per_account):
                try:
                    candidate = CandidateMemory(
                        category="entity",
                        owner_scope="user",
                        routing_key=f"entity_{i}",
                        abstract=f"Entity {i} for account {account_idx}",
                        overview=f"## Entity {i}",
                        content=f"Entity data for account {account_idx}",
                        confidence=0.9,
                    )
                    # A fresh API object per write; presumably so no API state
                    # is shared across threads (TODO confirm thread-safety).
                    api = MemoryWriteAPI(fs=self.fs, llm=self.llm, outbox_store=self.outbox_store)
                    api.write_memory(candidate, ctx)
                    with lock:
                        successful += 1
                except Exception as e:
                    with lock:
                        failed += 1
                        errors.append(f"Account-{account_idx}: {str(e)}")

        start_time = time.perf_counter()
        # max_workers must be >= 1 or ThreadPoolExecutor raises ValueError.
        with ThreadPoolExecutor(max_workers=max(1, num_accounts)) as executor:
            futures = [executor.submit(write_for_account, i) for i in range(num_accounts)]
            for future in as_completed(futures):
                future.result()
        duration = time.perf_counter() - start_time
        total_ops = num_accounts * writes_per_account
        ops_per_sec = self._rate(total_ops, duration)
        self._print_result(total_ops, successful, failed, duration, ops_per_sec)
        return StressTestResult(
            name="multi_account_concurrent",
            total_operations=total_ops,
            successful=successful,
            failed=failed,
            duration_seconds=duration,
            ops_per_second=ops_per_sec,
            errors=errors,
        )

    def test_concurrent_merge_consistency(self, num_threads: int = 10, writes_per_thread: int = 50) -> StressTestResult:
        """Test 3: concurrent writes to one routing_key (merge consistency).

        All threads write preference memories with the same routing key under
        one context. Afterwards it only checks that the target node exists;
        merged content is not inspected.

        Args:
            num_threads: Number of concurrent writer threads.
            writes_per_thread: Writes performed by each thread.

        Returns:
            The aggregated StressTestResult for this scenario.
        """
        self._print_banner(f"测试 3: 并发 Merge 一致性 ({num_threads} 线程 x {writes_per_thread} 次)")
        routing_key = "concurrent_test_preference"
        ctx = self.create_context("consistency-test", "user-1")
        successful = 0
        failed = 0
        errors: List[str] = []
        lock = threading.Lock()  # guards successful / failed / errors

        def write_concurrent(thread_id: int) -> None:
            nonlocal successful, failed
            for i in range(writes_per_thread):
                try:
                    candidate = CandidateMemory(
                        category="preference",
                        owner_scope="user",
                        routing_key=routing_key,
                        abstract=f"Concurrent write {thread_id}-{i}",
                        overview=f"## Concurrent Write\n\nThread {thread_id}, iteration {i}",
                        content=f"Content from thread {thread_id}, iteration {i}",
                        confidence=0.9,
                    )
                    api = MemoryWriteAPI(fs=self.fs, llm=self.llm, outbox_store=self.outbox_store)
                    api.write_memory(candidate, ctx)
                    with lock:
                        successful += 1
                except Exception as e:
                    with lock:
                        failed += 1
                        errors.append(f"Thread-{thread_id}: {str(e)}")

        start_time = time.perf_counter()
        # max_workers must be >= 1 or ThreadPoolExecutor raises ValueError.
        with ThreadPoolExecutor(max_workers=max(1, num_threads)) as executor:
            futures = [executor.submit(write_concurrent, i) for i in range(num_threads)]
            for future in as_completed(futures):
                future.result()
        duration = time.perf_counter() - start_time
        total_ops = num_threads * writes_per_thread
        ops_per_sec = self._rate(total_ops, duration)
        try:
            # URI layout mirrors the account/user used in ``ctx`` above.
            uri = f"ctx://consistency-test/users/user-1/memories/preferences/{routing_key}"
            exists = self.fs.exists(uri, ctx)
            print(f"\n 最终节点状态: {'存在' if exists else '不存在'}")
        except Exception as e:
            print(f"\n 验证节点时出错: {e}")
        self._print_result(total_ops, successful, failed, duration, ops_per_sec)
        return StressTestResult(
            name="concurrent_merge_consistency",
            total_operations=total_ops,
            successful=successful,
            failed=failed,
            duration_seconds=duration,
            ops_per_second=ops_per_sec,
            errors=errors,
        )

    def test_outbox_processing_stress(self, num_events: int = 500) -> StressTestResult:
        """Test 4: write many memories, then drain the outbox once.

        Writes ``num_events`` memories spread over 10 users and 50 routing
        keys of the "outbox-test" account, then runs a single
        ``OutboxWorker.run_once`` pass and prints its statistics.

        Args:
            num_events: Number of memories written before processing.

        Returns:
            The aggregated StressTestResult; counts reflect the write phase,
            duration covers both writing and outbox processing.
        """
        self._print_banner(f"测试 4: Outbox 处理压力测试 ({num_events} 事件)")
        # Imported lazily so other scenarios do not depend on the index package.
        from index.outbox_worker import OutboxWorker
        successful = 0
        failed = 0
        errors: List[str] = []
        start_time = time.perf_counter()
        for i in range(num_events):
            try:
                ctx = self.create_context("outbox-test", f"user-{i % 10}")
                candidate = CandidateMemory(
                    category="preference",
                    owner_scope="user",
                    routing_key=f"outbox_pref_{i % 50}",
                    abstract=f"Outbox test {i}",
                    overview="## Test",
                    content=f"Content {i}",
                    confidence=0.9,
                )
                self.api.write_memory(candidate, ctx)
                successful += 1
            except Exception as e:
                failed += 1
                errors.append(str(e))
            if (i + 1) % 100 == 0:
                elapsed = time.perf_counter() - start_time
                print(f" 进度: {i+1}/{num_events} ({self._rate(i + 1, elapsed):.1f} ops/s)")
        print("\n 处理 Outbox 事件...")
        worker = OutboxWorker(vector_index=self.vector_index, embedder=self.embedder)
        process_start = time.perf_counter()
        stats = worker.run_once(self.outbox_store, ["outbox-test"], worker_id="stress-worker")
        process_duration = time.perf_counter() - process_start
        duration = time.perf_counter() - start_time
        ops_per_sec = self._rate(num_events, duration)
        print("\nOutbox 处理统计:")
        print(f" 处理: {stats['processed']}")
        print(f" 成功: {stats['succeeded']}")
        print(f" 失败: {stats['failed']}")
        print(f" 跳过: {stats['skipped']}")
        print(f" 处理耗时: {process_duration:.2f} 秒")
        print("\n结果:")
        print(f" 总事件数: {num_events}")
        print(f" 写入成功: {successful}")
        print(f" 写入失败: {failed}")
        print(f" 总耗时: {duration:.2f} 秒")
        print(f" 吞吐量: {ops_per_sec:.2f} ops/s")
        return StressTestResult(
            name="outbox_processing_stress",
            total_operations=num_events,
            successful=successful,
            failed=failed,
            duration_seconds=duration,
            ops_per_second=ops_per_sec,
            errors=errors,
        )

    def run_all_tests(self) -> List[StressTestResult]:
        """Run all four scenarios with reduced sizes and return their results."""
        print(f"\n{'='*60}")
        print("ContextEngine 压力测试套件")
        print(f"开始时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        print(f"{'='*60}")
        return [
            self.test_single_account_large_scale(num_writes=500),
            self.test_multi_account_concurrent(num_accounts=10, writes_per_account=50),
            self.test_concurrent_merge_consistency(num_threads=10, writes_per_thread=20),
            self.test_outbox_processing_stress(num_events=200),
        ]

    def print_summary(self, results: List[StressTestResult]):
        """Print per-scenario status, overall totals and up to 5 errors per scenario.

        Args:
            results: Scenario results to summarize; may be empty.
        """
        self._print_banner("压力测试摘要")
        print()
        total_ops = sum(r.total_operations for r in results)
        total_successful = sum(r.successful for r in results)
        total_failed = sum(r.failed for r in results)
        for r in results:
            print(f" {r.name}")
            print(f" 状态: {r.successful}/{r.total_operations}")
            print(f" 吞吐量: {r.ops_per_second:.2f} ops/s")
            if r.errors:
                print(f" 错误数: {len(r.errors)}")
            print()
        print("总计:")
        print(f" 总操作数: {total_ops}")
        print(f" 成功: {total_successful}")
        print(f" 失败: {total_failed}")
        # Empty results or zero-operation runs would otherwise divide by zero.
        success_rate = f"{total_successful / total_ops * 100:.1f}%" if total_ops else "N/A"
        print(f" 成功率: {success_rate}")
        if total_failed > 0:
            print("\n失败详情:")
            for r in results:
                if not r.errors:
                    continue
                print(f"\n {r.name}:")
                for e in r.errors[:5]:
                    print(f" - {e}")
                if len(r.errors) > 5:
                    print(f" ... 还有 {len(r.errors) - 5} 个错误")
def main() -> int:
    """Parse CLI arguments, run the selected stress test(s) and print a summary.

    Returns:
        Process exit status: 0 when every operation succeeded, 1 when any
        scenario recorded a failure (so CI can detect failed runs).
    """
    parser = argparse.ArgumentParser(description="ContextEngine 压力测试")
    parser.add_argument("--test", choices=["all", "single", "multi", "merge", "outbox"],
                        default="all", help="要运行的测试")
    parser.add_argument("--agfs-url", default="http://localhost:1833",
                        help="AGFS 服务地址")
    parser.add_argument("--writes", type=int, default=500,
                        help="单测试写入数量")
    parser.add_argument("--workers", type=int, default=10,
                        help="multi/merge 测试的并发账户数/线程数")
    args = parser.parse_args()
    if args.writes < 1:
        parser.error("--writes 必须为正整数")
    if args.workers < 1:
        parser.error("--workers 必须为正整数")

    runner = StressTestRunner(agfs_base_url=args.agfs_url)
    # Split --writes across the workers, but never give a worker zero writes:
    # a zero-operation run measures nothing and breaks rate calculations.
    per_worker = max(1, args.writes // args.workers)

    if args.test == "all":
        results = runner.run_all_tests()
    elif args.test == "single":
        results = [runner.test_single_account_large_scale(num_writes=args.writes)]
    elif args.test == "multi":
        results = [runner.test_multi_account_concurrent(
            num_accounts=args.workers, writes_per_account=per_worker)]
    elif args.test == "merge":
        results = [runner.test_concurrent_merge_consistency(
            num_threads=args.workers, writes_per_thread=per_worker)]
    else:  # "outbox" — argparse `choices` rules out any other value
        results = [runner.test_outbox_processing_stress(num_events=args.writes)]
    runner.print_summary(results)
    return 0 if all(r.failed == 0 for r in results) else 1


if __name__ == "__main__":
    sys.exit(main())