"""
-------------------------------------------------------------------------
This file is part of the RAGSDK project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
RAGSDK is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
import asyncio
import json
import random
import time
import uuid
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
# Application object; the __main__ guard at the bottom serves it with uvicorn.
app = FastAPI(version="1.0", title="Mock Embedding API Server")

# Fully open CORS so a browser client on any origin can hit the mock.
# NOTE(review): wildcard origins together with credentials is only tolerable
# because this server binds to loopback in __main__; do not reuse as-is.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
class ChatMessage(BaseModel):
    """One entry of an OpenAI-style ``messages`` array."""

    # Speaker role. The schema description lists user/assistant/system, but any
    # string passes validation; the mock reply only looks for role == "user".
    role: str = Field(..., description="角色:user/assistant/system")
    # Message text (required).
    content: str = Field(..., description="消息内容")
class ChatCompletionRequest(BaseModel):
    """Request body shared by the mock ``/v1/chat/completions*`` endpoints."""

    # Echoed back verbatim in the response's "model" field.
    model: Optional[str] = Field("default-model", description="模型名称(可选)")
    # Conversation history; the mock reply echoes the last "user" message.
    messages: List[ChatMessage]
    # Validated to [0, 2]; not read by any handler in this file.
    temperature: Optional[float] = Field(1.0, ge=0, le=2)
    # Validated to be >= 1; not read by any handler in this file.
    max_tokens: Optional[int] = Field(1024, ge=1)
    # Not read by any handler in this file: streaming is chosen by calling the
    # /v1/chat/completions_stream route instead.
    stream: Optional[bool] = Field(False)
class EmbeddingRequest(BaseModel):
    """Request body for the mock ``/v1/embeddings`` endpoint."""

    # Echoed back in the response's "model" field.
    model: Optional[str] = Field("default-embed-model", description="模型名称(可选)")
    # A single text or a list of texts; each text gets its own mock vector.
    input: List[str] | str
    # Accepted for OpenAI compatibility; not read by the handler in this file,
    # which always returns float lists.
    encoding_format: Optional[str] = Field("float")
class RerankRequest(BaseModel):
    """Request body for the mock ``/rerank`` endpoint."""

    # Query string; the handler answers HTTP 400 when it is empty.
    query: str
    # Candidate passages to score; the handler answers HTTP 400 when empty.
    texts: List[str]
class ClipEmbeddingRequest(BaseModel):
    """Request body for the mock ``/encode_clip`` endpoint."""

    # Items to embed. The handler reads item.get("text"), so items are
    # presumably dicts with a "text" key.
    # NOTE(review): confirm the schema for non-text (e.g. image) items.
    data: list
    # Free-form options; not read by the handler in this file.
    parameters: dict
def generate_mock_chat_response(messages: "List[ChatMessage]") -> str:
    """Build a canned assistant reply that echoes the latest user message.

    Args:
        messages: Conversation history. Only entries whose ``role`` is
            ``"user"`` are considered; the most recent one is echoed.

    Returns:
        The mock reply text. If the conversation contains no user message
        (including an empty list), the echoed part is empty. Previously this
        raised ``IndexError``, which surfaced to clients as HTTP 500.
    """
    # Walk backwards so the first match is the most recent user turn.
    last_user_msg = next(
        (msg.content for msg in reversed(messages) if msg.role == "user"),
        "",
    )
    return f"模拟回复:{last_user_msg} 的相关内容..."
def generate_mock_embedding(text: str, dim: int = 1024) -> List[float]:
    """Return a deterministic pseudo-random embedding for ``text``.

    Args:
        text: Text whose content seeds the generator; equal texts always map
            to equal vectors.
        dim: Vector length (defaults to 1024, the original fixed size).

    Returns:
        ``dim`` floats drawn uniformly from [-1.0, 1.0].
    """
    # A private generator seeded with the string itself:
    # - random.Random hashes str seeds with SHA-512, so the vector stays stable
    #   across processes. The built-in hash() is salted per process.
    # - The module-level RNG is not reseeded, so other random.* calls (latency,
    #   usage counts) keep their own state.
    rng = random.Random(text)
    return [rng.uniform(-1.0, 1.0) for _ in range(dim)]
def _mock_usage():
    """Return random token counts whose total equals prompt + completion."""
    prompt_tokens = random.randint(10, 100)
    completion_tokens = random.randint(10, 200)
    return {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }


def assemble_completion(request, case):
    """Build an OpenAI-style chat completion payload for the mock endpoints.

    Args:
        request: Parsed chat request; ``request.model`` is echoed back and
            ``request.messages`` feeds the mock reply.
        case: ``"qa_generate"`` uses a fixed QA-generation answer.
            ``"stream"`` returns one ``chat.completion.chunk`` per character
            of the reply. Any other value (e.g. ``"common"``) returns a single
            ``chat.completion`` echoing the last user message.

    Returns:
        A completion dict, or a list of chunk dicts when ``case == "stream"``.
    """
    if case == "qa_generate":
        content = "Q1:2024年高考语文作文题目是什么?\n参考段落:新课标Ⅰ卷。"
    else:
        content = generate_mock_chat_response(request.messages)
    # "created" is a Unix timestamp in the OpenAI schema. The event loop's
    # time() is a monotonic clock with an arbitrary epoch, so use wall time.
    created = int(time.time())

    if case != "stream":
        return {
            "id": f"chatcmpl-{uuid.uuid4()}",
            "object": "chat.completion",
            "created": created,
            "model": request.model,
            "choices": [{
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }],
            "usage": _mock_usage(),
        }

    # All chunks of one streamed completion share a single id, and
    # choices[].index is the choice index (always 0 here), not the chunk number.
    completion_id = f"chatcmpl-{uuid.uuid4()}"
    last = len(content) - 1
    return [
        {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": request.model,
            "choices": [{
                "index": 0,
                "delta": {"role": "assistant", "content": char},
                # Only the final character's chunk closes the stream.
                "finish_reason": "stop" if i == last else None,
            }],
            "usage": _mock_usage(),
        }
        for i, char in enumerate(content)
    ]
@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """Non-streaming chat completion mock.

    Waits a random 0.1-1.0 s to imitate model latency, then returns a single
    ``chat.completion`` payload built from the request.
    """
    simulated_latency = random.uniform(0.1, 1.0)
    await asyncio.sleep(simulated_latency)
    payload = assemble_completion(request, "common")
    return payload
@app.post("/v1/chat/completions_qa_generate")
async def completions_qa_generate(request: ChatCompletionRequest):
    """Return a fixed QA-generation answer in ``chat.completion`` form, without delay."""
    payload = assemble_completion(request, case="qa_generate")
    return payload
@app.post("/v1/chat/completions_stream")
async def completions_stream(request: ChatCompletionRequest):
    """Stream the mock reply as one JSON chunk per character, 0.1 s apart.

    NOTE(review): each chunk is emitted as a bare JSON line (newline-delimited
    JSON), not an SSE ``data:`` frame, even though the media type is
    ``text/event-stream``. Clients expecting OpenAI's SSE framing may not
    parse it.
    """
    chunks = assemble_completion(request, "stream")

    async def emit_lines():
        # Pace the chunks so clients observe incremental delivery.
        for chunk in chunks:
            await asyncio.sleep(0.1)
            yield f"{json.dumps(chunk)}\n"

    return StreamingResponse(emit_lines(), media_type="text/event-stream")
@app.post("/v1/embeddings")
async def embeddings(request: EmbeddingRequest):
    """OpenAI-style embeddings mock.

    After a random 0.05-0.5 s delay, returns one mock vector per input text,
    in input order. Token usage is approximated as the count of
    whitespace-separated words across all texts.

    Raises:
        HTTPException: 400 when ``input`` is an empty list.
    """
    await asyncio.sleep(random.uniform(0.05, 0.5))
    raw_input = request.input
    if isinstance(raw_input, str):
        texts = [raw_input]
    else:
        texts = raw_input
    if not texts:
        raise HTTPException(status_code=400, detail="输入文本不能为空")
    vectors = [
        {
            "object": "embedding",
            "index": position,
            "embedding": generate_mock_embedding(text),
        }
        for position, text in enumerate(texts)
    ]
    word_count = sum(len(text.split()) for text in texts)
    return {
        "object": "list",
        "data": vectors,
        "model": request.model,
        "usage": {"prompt_tokens": word_count, "total_tokens": word_count},
    }
@app.get("/health")
async def health_check():
    """Liveness probe: unconditionally reports the server as healthy."""
    status = {"status": "healthy"}
    return status
@app.post("/rerank")
async def rerank(request: RerankRequest):
    """Score every candidate text with a deterministic pseudo-random value in [0, 1].

    The score depends only on the text, not on the query, so a given passage
    always gets the same score, including across server restarts.

    Raises:
        HTTPException: 400 when ``texts`` is empty or ``query`` is empty.

    Returns:
        A list of ``{"index": i, "score": s}`` dicts in input order.
    """
    # Pydantic has already validated the types (list of str / str), so only
    # emptiness needs checking.
    if not request.texts:
        raise HTTPException(status_code=400, detail="`texts` cannot be empty")
    if not request.query:
        raise HTTPException(status_code=400, detail="`query` cannot be empty")
    # A per-text private RNG replaces random.seed(hash(text)). hash() of a str
    # is salted per process, and reseeding the global RNG clobbered every
    # other random.* caller.
    return [
        {"index": idx, "score": random.Random(text).uniform(0.0, 1.0)}
        for idx, text in enumerate(request.texts)
    ]
@app.post("/encode_clip")
async def encode_clip(request: ClipEmbeddingRequest):
    """Return one deterministic mock embedding per item in ``request.data``.

    Items are expected to be dicts carrying a ``"text"`` key; a bare string
    item is treated as the text itself.
    - Previously a non-dict item crashed with ``AttributeError`` (HTTP 500).
    - Items without a string text (presumably image payloads; confirm against
      callers) are keyed by their string form. Distinct items therefore get
      distinct vectors instead of all sharing the vector for ``None``.

    Returns:
        ``{"data": [{"embedding": [...]}, ...]}`` in input order.
    """
    results = []
    for item in request.data:
        text = item.get("text") if isinstance(item, dict) else item
        seed_text = text if isinstance(text, str) else str(item)
        # Reuse the shared helper so a text item gets the same vector as the
        # same text sent to /v1/embeddings.
        results.append({"embedding": generate_mock_embedding(seed_text)})
    return {"data": results}
@app.post("/text2img")
async def text2img():
    """Placeholder image-generation endpoint.

    Ignores any request body and returns a fixed string instead of real
    base64-encoded image data.
    """
    placeholder = "base64 mock image"
    return placeholder
if __name__ == "__main__":
    # uvicorn is imported only here, so merely importing this module (e.g. to
    # reach `app`) does not need it.
    from uvicorn import run as serve

    # Loopback-only bind on port 8000.
    serve(app, host="127.0.0.1", port=8000)