@@ -113,6 +113,8 @@ extern "C" {
GGML_API ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend);
+ GGML_API void ggml_backend_set_priority(ggml_backend_t backend, int prio);
+
//
// Events
//
@@ -114,6 +114,9 @@ extern "C" {
void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
// wait for an event on on a different stream
void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
+
+ // (optional) backend context priority
+ void (*set_priority)(ggml_backend_t backend, int prio);
};
struct ggml_backend {
@@ -350,6 +350,12 @@ ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
return backend->device;
}
+void ggml_backend_set_priority(ggml_backend_t backend, int prio) {
+ if(backend->iface.set_priority != nullptr) {
+ backend->iface.set_priority(backend,prio);
+ }
+}
+
// backend copy
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
@@ -186,6 +186,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
/* .graph_compute = */ ggml_backend_cpu_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
+ /* .set_priority = */ NULL,
};
static ggml_guid_t ggml_backend_cpu_guid(void) {
@@ -1,6 +1,7 @@
cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
find_package(CUDAToolkit)
+find_package(XSched REQUIRED)
if (CUDAToolkit_FOUND)
message(STATUS "CUDA Toolkit found")
@@ -102,12 +103,12 @@ if (CUDAToolkit_FOUND)
if (GGML_STATIC)
if (WIN32)
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt XSched::preempt XSched::halcuda)
else ()
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static XSched::preempt XSched::halcuda)
endif()
else()
- target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt)
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt XSched::preempt XSched::halcuda)
endif()
if (GGML_CUDA_NO_VMM)
@@ -34,6 +34,9 @@
#include "vendors/cuda.h"
#endif // defined(GGML_USE_HIP)
+#include "xsched/xsched.h"
+#include "xsched/cuda/hal.h"
+
#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
@@ -800,6 +803,8 @@ struct ggml_backend_cuda_context {
std::unique_ptr<ggml_cuda_graph> cuda_graph;
+ int priority = 0;
+
explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
@@ -811,6 +816,11 @@ struct ggml_backend_cuda_context {
if (streams[device][stream] == nullptr) {
ggml_cuda_set_device(device);
CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
+ HwQueueHandle hwqueue;
+ CudaQueueCreate(&hwqueue,streams[device][stream]);
+ XQueueHandle xqueue;
+ XQueueCreate(&xqueue, hwqueue, kPreemptLevelDeactivate, kQueueCreateFlagNone);
+ XHintPriority(xqueue, priority); // In XSched, lower number means lower priority
}
return streams[device][stream];
}
@@ -64,6 +64,9 @@
#include <string>
#include <vector>
+#include "xsched/xsched.h"
+#include "xsched/cuda/hal.h"
+
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
[[noreturn]]
@@ -2850,6 +2853,24 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
}
}
+static void ggml_backend_cuda_set_priority(ggml_backend_t backend, int prio) {
+ ggml_backend_cuda_context *cuda_ctx = (ggml_backend_cuda_context *)backend->context;
+ for (int device = 0; device < GGML_CUDA_MAX_DEVICES; device++) {
+ for (int idx = 0; idx < GGML_CUDA_MAX_STREAMS; idx++) {
+ auto stream = cuda_ctx->streams[device][idx];
+ if(stream == nullptr) {
+ continue;
+ }
+ HwQueueHandle hwqueue;
+ CudaQueueCreate(&hwqueue,stream);
+ XQueueHandle xqueue;
+ XQueueCreate(&xqueue, hwqueue, kPreemptLevelDeactivate, kQueueCreateFlagNone);
+ XHintPriority(xqueue, prio); // In XSched, lower number means lower priority
+ }
+ }
+ cuda_ctx->priority = prio;
+}
+
static const ggml_backend_i ggml_backend_cuda_interface = {
/* .get_name = */ ggml_backend_cuda_get_name,
/* .free = */ ggml_backend_cuda_free,
@@ -2864,6 +2885,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
/* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait,
+ /* .set_priority = */ ggml_backend_cuda_set_priority,
};
static ggml_guid_t ggml_backend_cuda_guid() {
@@ -1477,6 +1477,8 @@ extern "C" {
ggml_opt_epoch_callback callback_train,
ggml_opt_epoch_callback callback_eval);
+ LLAMA_API void llama_set_priority(struct llama_context *ctx, int prio);
+
#ifdef __cplusplus
}
#endif
@@ -2152,6 +2152,15 @@ void llama_context::opt_epoch(
llama_batch_free(batch);
}
+//
+// priority
+//
+void llama_context::set_priority(int prio) {
+ for(auto backend: backend_ptrs) {
+ ggml_backend_set_priority(backend, prio);
+ }
+}
+
//
// interface implementation
//
@@ -2843,3 +2852,7 @@ void llama_opt_epoch(
callback_train,
callback_eval);
}
+
+void llama_set_priority(struct llama_context *ctx, int prio) {
+ ctx->set_priority(prio);
+}
\ No newline at end of file
@@ -174,6 +174,10 @@ struct llama_context {
int64_t ndata_in_loop,
int64_t t_loop_start);
+ //
+ // priority
+ //
+ void set_priority(int prio);
private:
//
// output
@@ -9,7 +9,11 @@ CHAT=(
"Sure. The largest city in Europe is Moscow, the capital of Russia."
)
-INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+PRIO=$1
+echo "priority: $PRIO"
+
+# INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
+INSTRUCTION="A chat between a human and an AI assistant. The assistant has no word limits."
trim() {
shopt -s extglob
@@ -41,20 +45,23 @@ N_KEEP=$(tokenize "${INSTRUCTION}" | wc -l)
chat_completion() {
PROMPT="$(trim_trailing "$(format_prompt "$1")")"
- DATA="$(echo -n "$PROMPT" | jq -Rs --argjson n_keep $N_KEEP '{
+ DATA="$(echo -n "$PROMPT" | jq -Rs --argjson n_keep $N_KEEP --argjson prio $PRIO '{
prompt: .,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: $n_keep,
- n_predict: 256,
+ n_predict: 4096,
cache_prompt: true,
stop: ["\n### Human:"],
- stream: true
+ stream: true,
+ priority: $prio,
}')"
ANSWER=''
+ echo $DATA
+
while IFS= read -r LINE; do
if [[ $LINE = data:* ]]; then
CONTENT="$(echo "${LINE:5}" | jq -r '.content')"
@@ -51,6 +51,12 @@ enum slot_state {
SLOT_STATE_GENERATING,
};
+enum server_task_prio {
+ SERVER_TASK_PRIO_NORMAL,
+ SERVER_TASK_PRIO_HIGH,
+ SERVER_TASK_PRIO_COUNT,
+};
+
enum server_state {
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
SERVER_STATE_READY, // Server is ready and model is loaded
@@ -215,6 +221,7 @@ struct server_task {
int index = -1; // used when there are multiple prompts (batch request)
server_task_type type;
+ server_task_prio prio;
// used by SERVER_TASK_TYPE_CANCEL
int id_target = -1;
@@ -238,7 +245,7 @@ struct server_task {
// used by SERVER_TASK_TYPE_SET_LORA
std::vector<common_adapter_lora_info> set_lora;
- server_task(server_task_type type) : type(type) {}
+ server_task(server_task_type type) : type(type), prio(SERVER_TASK_PRIO_NORMAL) {}
static slot_params params_from_json_cmpl(
const llama_context * ctx,
@@ -1917,6 +1924,8 @@ struct server_context {
common_chat_templates_ptr chat_templates;
oaicompat_parser_options oai_parser_opt;
+ server_task_prio server_prio;
+
~server_context() {
mtmd_free(mctx);
@@ -1937,15 +1946,21 @@ struct server_context {
llama_batch_free(batch);
}
- bool load_model(const common_params & params) {
+ bool load_model(const common_params & params, llama_model *model_ref = nullptr) {
SRV_INF("loading model '%s'\n", params.model.path.c_str());
params_base = params;
- llama_init = common_init_from_params(params_base);
-
- model = llama_init.model.get();
- ctx = llama_init.context.get();
+ if(model_ref == nullptr) {
+ llama_init = common_init_from_params(params_base);
+
+ model = llama_init.model.get();
+ ctx = llama_init.context.get();
+ } else {
+ model = model_ref;
+ ctx = llama_init_from_model(model, common_context_params_to_llama(params));
+ llama_init.context.reset(ctx);
+ }
if (model == nullptr) {
SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str());
@@ -2117,6 +2132,11 @@ struct server_context {
};
}
+ void set_priority(server_task_prio prio) {
+ llama_set_priority(ctx, prio);
+ server_prio = prio;
+ }
+
server_slot * get_slot_by_id(int id) {
for (server_slot & slot : slots) {
if (slot.id == id) {
@@ -3115,6 +3135,7 @@ struct server_context {
if (prompt_tokens.empty()) {
SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
+ SRV_INF("server prio %d\n",server_prio);
slot.release();
slot.print_timings();
send_final_response(slot);
@@ -3506,6 +3527,7 @@ struct server_context {
if (!process_token(result, slot)) {
// release slot because of stop condition
+ SRV_INF("server prio %d\n",server_prio);
slot.release();
slot.print_timings();
send_final_response(slot);
@@ -3605,6 +3627,7 @@ struct server_context {
if (!process_token(result, slot)) {
// release slot because of stop condition
+ SRV_INF("server prio %d",server_prio);
slot.release();
slot.print_timings();
send_final_response(slot);
@@ -3671,7 +3694,7 @@ int main(int argc, char ** argv) {
common_init();
// struct that contains llama context and inference
- server_context ctx_server;
+ server_context ctx_server[SERVER_TASK_PRIO_COUNT];
llama_backend_init();
llama_numa_init(params.numa);
@@ -3759,7 +3782,9 @@ int main(int argc, char ** argv) {
}
// Necessary similarity of prompt for slot selection
- ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
+ for(int i = 0; i < SERVER_TASK_PRIO_COUNT; i++) {
+ ctx_server[i].slot_prompt_similarity = params.slot_prompt_similarity;
+ }
//
// Middlewares
@@ -3857,17 +3882,17 @@ int main(int argc, char ** argv) {
}
// request slots data using task queue
- int task_id = ctx_server.queue_tasks.get_new_id();
+ int task_id = ctx_server[SERVER_TASK_PRIO_NORMAL].queue_tasks.get_new_id();
{
server_task task(SERVER_TASK_TYPE_METRICS);
task.id = task_id;
- ctx_server.queue_results.add_waiting_task_id(task_id);
- ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.add_waiting_task_id(task_id);
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_tasks.post(std::move(task), true); // high-priority task
}
// get the result
- server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
- ctx_server.queue_results.remove_waiting_task_id(task_id);
+ server_task_result_ptr result = ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.recv(task_id);
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
@@ -3896,17 +3921,17 @@ int main(int argc, char ** argv) {
}
// request slots data using task queue
- int task_id = ctx_server.queue_tasks.get_new_id();
+ int task_id = ctx_server[SERVER_TASK_PRIO_NORMAL].queue_tasks.get_new_id();
{
server_task task(SERVER_TASK_TYPE_METRICS);
task.id = task_id;
- ctx_server.queue_results.add_waiting_task_id(task_id);
- ctx_server.queue_tasks.post(std::move(task), true); // high-priority task
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.add_waiting_task_id(task_id);
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_tasks.post(std::move(task), true); // high-priority task
}
// get the result
- server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
- ctx_server.queue_results.remove_waiting_task_id(task_id);
+ server_task_result_ptr result = ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.recv(task_id);
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
@@ -3995,7 +4020,7 @@ int main(int argc, char ** argv) {
}
std::string filepath = params.slot_save_path + filename;
- int task_id = ctx_server.queue_tasks.get_new_id();
+ int task_id = ctx_server[SERVER_TASK_PRIO_NORMAL].queue_tasks.get_new_id();
{
server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
task.id = task_id;
@@ -4003,12 +4028,12 @@ int main(int argc, char ** argv) {
task.slot_action.filename = filename;
task.slot_action.filepath = filepath;
- ctx_server.queue_results.add_waiting_task_id(task_id);
- ctx_server.queue_tasks.post(std::move(task));
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.add_waiting_task_id(task_id);
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_tasks.post(std::move(task));
}
- server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
- ctx_server.queue_results.remove_waiting_task_id(task_id);
+ server_task_result_ptr result = ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.recv(task_id);
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
@@ -4027,7 +4052,7 @@ int main(int argc, char ** argv) {
}
std::string filepath = params.slot_save_path + filename;
- int task_id = ctx_server.queue_tasks.get_new_id();
+ int task_id = ctx_server[SERVER_TASK_PRIO_NORMAL].queue_tasks.get_new_id();
{
server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
task.id = task_id;
@@ -4035,12 +4060,12 @@ int main(int argc, char ** argv) {
task.slot_action.filename = filename;
task.slot_action.filepath = filepath;
- ctx_server.queue_results.add_waiting_task_id(task_id);
- ctx_server.queue_tasks.post(std::move(task));
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.add_waiting_task_id(task_id);
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_tasks.post(std::move(task));
}
- server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
- ctx_server.queue_results.remove_waiting_task_id(task_id);
+ server_task_result_ptr result = ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.recv(task_id);
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
@@ -4052,18 +4077,18 @@ int main(int argc, char ** argv) {
};
const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
- int task_id = ctx_server.queue_tasks.get_new_id();
+ int task_id = ctx_server[SERVER_TASK_PRIO_NORMAL].queue_tasks.get_new_id();
{
server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
task.id = task_id;
task.slot_action.slot_id = id_slot;
- ctx_server.queue_results.add_waiting_task_id(task_id);
- ctx_server.queue_tasks.post(std::move(task));
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.add_waiting_task_id(task_id);
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_tasks.post(std::move(task));
}
- server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
- ctx_server.queue_results.remove_waiting_task_id(task_id);
+ server_task_result_ptr result = ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.recv(task_id);
+ ctx_server[SERVER_TASK_PRIO_NORMAL].queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
@@ -4106,20 +4131,20 @@ int main(int argc, char ** argv) {
const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
// this endpoint is publicly available, please only return what is safe to be exposed
json data = {
- { "default_generation_settings", ctx_server.default_generation_settings_for_props },
- { "total_slots", ctx_server.params_base.n_parallel },
- { "model_path", ctx_server.params_base.model.path },
+ { "default_generation_settings", ctx_server[SERVER_TASK_PRIO_NORMAL].default_generation_settings_for_props },
+ { "total_slots", ctx_server[SERVER_TASK_PRIO_NORMAL].params_base.n_parallel },
+ { "model_path", ctx_server[SERVER_TASK_PRIO_NORMAL].params_base.model.path },
{ "modalities", json{
- {"vision", ctx_server.oai_parser_opt.allow_image},
- {"audio", ctx_server.oai_parser_opt.allow_audio},
+ {"vision", ctx_server[SERVER_TASK_PRIO_NORMAL].oai_parser_opt.allow_image},
+ {"audio", ctx_server[SERVER_TASK_PRIO_NORMAL].oai_parser_opt.allow_audio},
} },
- { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) },
- { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)},
- { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)},
+ { "chat_template", common_chat_templates_source(ctx_server[SERVER_TASK_PRIO_NORMAL].chat_templates.get()) },
+ { "bos_token", common_token_to_piece(ctx_server[SERVER_TASK_PRIO_NORMAL].ctx, llama_vocab_bos(ctx_server[SERVER_TASK_PRIO_NORMAL].vocab), /* special= */ true)},
+ { "eos_token", common_token_to_piece(ctx_server[SERVER_TASK_PRIO_NORMAL].ctx, llama_vocab_eos(ctx_server[SERVER_TASK_PRIO_NORMAL].vocab), /* special= */ true)},
{ "build_info", build_info },
};
- if (ctx_server.params_base.use_jinja) {
- if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) {
+ if (ctx_server[SERVER_TASK_PRIO_NORMAL].params_base.use_jinja) {
+ if (auto tool_use_src = common_chat_templates_source(ctx_server[SERVER_TASK_PRIO_NORMAL].chat_templates.get(), "tool_use")) {
data["chat_template_tool_use"] = tool_use_src;
}
}
@@ -4128,7 +4153,7 @@ int main(int argc, char ** argv) {
};
const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
- if (!ctx_server.params_base.endpoint_props) {
+ if (!ctx_server[SERVER_TASK_PRIO_NORMAL].params_base.endpoint_props) {
res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
@@ -4143,16 +4168,16 @@ int main(int argc, char ** argv) {
const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
json data = {
{
- "template", common_chat_templates_source(ctx_server.chat_templates.get()),
+ "template", common_chat_templates_source(ctx_server[SERVER_TASK_PRIO_NORMAL].chat_templates.get()),
},
{
"model_info", {
- { "llama.context_length", ctx_server.slots.back().n_ctx, },
+ { "llama.context_length", ctx_server[SERVER_TASK_PRIO_NORMAL].slots.back().n_ctx, },
}
},
{"modelfile", ""},
{"parameters", ""},
- {"template", common_chat_templates_source(ctx_server.chat_templates.get())},
+ {"template", common_chat_templates_source(ctx_server[SERVER_TASK_PRIO_NORMAL].chat_templates.get())},
{"details", {
{"parent_model", ""},
{"format", "gguf"},
@@ -4181,6 +4206,9 @@ int main(int argc, char ** argv) {
auto completion_id = gen_chatcmplid();
std::unordered_set<int> task_ids;
+ server_task_prio prio = json_value(data,"priority",SERVER_TASK_PRIO_NORMAL);
+ // printf("receive task with prio %d\n",prio);
+ SRV_INF("receive task with prio %d\n",prio);
try {
std::vector<server_task> tasks;
@@ -4190,13 +4218,13 @@ int main(int argc, char ** argv) {
// process files
mtmd::bitmaps bitmaps;
- const bool has_mtmd = ctx_server.mctx != nullptr;
+ const bool has_mtmd = ctx_server[prio].mctx != nullptr;
{
if (!has_mtmd && !files.empty()) {
throw std::runtime_error("This server does not support multimodal");
}
for (auto & file : files) {
- mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server.mctx, file.data(), file.size()));
+ mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(ctx_server[prio].mctx, file.data(), file.size()));
if (!bmp.ptr) {
throw std::runtime_error("Failed to load image or audio file");
}
@@ -4223,7 +4251,7 @@ int main(int argc, char ** argv) {
};
mtmd::input_chunks chunks(mtmd_input_chunks_init());
auto bitmaps_c_ptr = bitmaps.c_ptr();
- int32_t tokenized = mtmd_tokenize(ctx_server.mctx,
+ int32_t tokenized = mtmd_tokenize(ctx_server[prio].mctx,
chunks.ptr.get(),
&inp_txt,
bitmaps_c_ptr.data(),
@@ -4236,9 +4264,9 @@ int main(int argc, char ** argv) {
inputs.push_back(std::move(tmp));
} else {
// non-multimodal version
- auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+ auto tokenized_prompts = tokenize_input_prompts(ctx_server[prio].vocab, prompt, true, true);
for (auto & p : tokenized_prompts) {
- auto tmp = server_tokens(p, ctx_server.mctx != nullptr);
+ auto tmp = server_tokens(p, ctx_server[prio].mctx != nullptr);
inputs.push_back(std::move(tmp));
}
}
@@ -4247,13 +4275,13 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < inputs.size(); i++) {
server_task task = server_task(type);
- task.id = ctx_server.queue_tasks.get_new_id();
+ task.id = ctx_server[prio].queue_tasks.get_new_id();
task.index = i;
task.prompt_tokens = std::move(inputs[i]);
task.params = server_task::params_from_json_cmpl(
- ctx_server.ctx,
- ctx_server.params_base,
+ ctx_server[prio].ctx,
+ ctx_server[prio].params_base,
data);
task.id_selected_slot = json_value(data, "id_slot", -1);
@@ -4266,8 +4294,8 @@ int main(int argc, char ** argv) {
}
task_ids = server_task::get_list_id(tasks);
- ctx_server.queue_results.add_waiting_tasks(tasks);
- ctx_server.queue_tasks.post(std::move(tasks));
+ ctx_server[prio].queue_results.add_waiting_tasks(tasks);
+ ctx_server[prio].queue_tasks.post(std::move(tasks));
} catch (const std::exception & e) {
res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
return;
@@ -4276,7 +4304,7 @@ int main(int argc, char ** argv) {
bool stream = json_value(data, "stream", false);
if (!stream) {
- ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+ ctx_server[prio].receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
if (results.size() == 1) {
// single result
res_ok(res, results[0]->to_json());
@@ -4292,10 +4320,10 @@ int main(int argc, char ** argv) {
res_error(res, error_data);
}, is_connection_closed);
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+ ctx_server[prio].queue_results.remove_waiting_task_ids(task_ids);
} else {
- const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
- ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
+ const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat, prio](size_t, httplib::DataSink & sink) {
+ ctx_server[prio].receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
json res_json = result->to_json();
if (res_json.is_array()) {
for (const auto & res : res_json) {
@@ -4322,8 +4350,8 @@ int main(int argc, char ** argv) {
return false;
};
- auto on_complete = [task_ids, &ctx_server] (bool) {
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+ auto on_complete = [task_ids, &ctx_server, prio] (bool) {
+ ctx_server[prio].queue_results.remove_waiting_task_ids(task_ids);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -4357,13 +4385,13 @@ int main(int argc, char ** argv) {
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
// check model compatibility
std::string err;
- if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
+ if (llama_vocab_fim_pre(ctx_server[SERVER_TASK_PRIO_NORMAL].vocab) == LLAMA_TOKEN_NULL) {
err += "prefix token is missing. ";
}
- if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
+ if (llama_vocab_fim_suf(ctx_server[SERVER_TASK_PRIO_NORMAL].vocab) == LLAMA_TOKEN_NULL) {
err += "suffix token is missing. ";
}
- if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
+ if (llama_vocab_fim_mid(ctx_server[SERVER_TASK_PRIO_NORMAL].vocab) == LLAMA_TOKEN_NULL) {
err += "middle token is missing. ";
}
if (!err.empty()) {
@@ -4408,18 +4436,19 @@ int main(int argc, char ** argv) {
}
data["input_extra"] = input_extra; // default to empty array if it's not exist
+ server_task_prio prio = json_value(data,"priority",SERVER_TASK_PRIO_NORMAL);
std::string prompt = json_value(data, "prompt", std::string());
- std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true);
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server[prio].vocab, prompt, false, true);
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
data["prompt"] = format_infill(
- ctx_server.vocab,
+ ctx_server[prio].vocab,
data.at("input_prefix"),
data.at("input_suffix"),
data.at("input_extra"),
- ctx_server.params_base.n_batch,
- ctx_server.params_base.n_predict,
- ctx_server.slots[0].n_ctx, // TODO: there should be a better way
- ctx_server.params_base.spm_infill,
+ ctx_server[prio].params_base.n_batch,
+ ctx_server[prio].params_base.n_predict,
+ ctx_server[prio].slots[0].n_ctx, // TODO: there should be a better way
+ ctx_server[prio].params_base.spm_infill,
tokenized_prompts[0]
);
@@ -4438,9 +4467,10 @@ int main(int argc, char ** argv) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files;
+ server_task_prio prio = json_value(body,"priority",SERVER_TASK_PRIO_NORMAL);
json data = oaicompat_chat_params_parse(
body,
- ctx_server.oai_parser_opt,
+ ctx_server[prio].oai_parser_opt,
files);
handle_completions_impl(
@@ -4456,9 +4486,10 @@ int main(int argc, char ** argv) {
const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files; // dummy, unused
+ server_task_prio prio = json_value(body,"priority",SERVER_TASK_PRIO_NORMAL);
json data = oaicompat_chat_params_parse(
body,
- ctx_server.oai_parser_opt,
+ ctx_server[prio].oai_parser_opt,
files);
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};
@@ -4467,7 +4498,7 @@ int main(int argc, char ** argv) {
server_state current_state = state.load();
json model_meta = nullptr;
if (current_state == SERVER_STATE_READY) {
- model_meta = ctx_server.model_meta();
+ model_meta = ctx_server[SERVER_TASK_PRIO_NORMAL].model_meta();
}
json models = {
@@ -4516,11 +4547,11 @@ int main(int argc, char ** argv) {
const bool add_special = json_value(body, "add_special", false);
const bool with_pieces = json_value(body, "with_pieces", false);
- llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true);
+ llama_tokens tokens = tokenize_mixed(ctx_server[SERVER_TASK_PRIO_NORMAL].vocab, body.at("content"), add_special, true);
if (with_pieces) {
for (const auto& token : tokens) {
- std::string piece = common_token_to_piece(ctx_server.ctx, token);
+ std::string piece = common_token_to_piece(ctx_server[SERVER_TASK_PRIO_NORMAL].ctx, token);
json piece_json;
// Check if the piece is valid UTF-8
@@ -4554,7 +4585,7 @@ int main(int argc, char ** argv) {
std::string content;
if (body.count("tokens") != 0) {
const llama_tokens tokens = body.at("tokens");
- content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
+ content = tokens_to_str(ctx_server[SERVER_TASK_PRIO_NORMAL].ctx, tokens.cbegin(), tokens.cend());
}
const json data = format_detokenized_response(content);
@@ -4562,18 +4593,18 @@ int main(int argc, char ** argv) {
};
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
- if (!ctx_server.params_base.embedding) {
+ if (!ctx_server[SERVER_TASK_PRIO_NORMAL].params_base.embedding) {
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
- if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+ if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server[SERVER_TASK_PRIO_NORMAL].ctx) == LLAMA_POOLING_TYPE_NONE) {
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
return;
}
const json body = json::parse(req.body);
-
+ server_task_prio prio = json_value(body,"priority",SERVER_TASK_PRIO_NORMAL);
// for the shape of input/content, see tokenize_input_prompts()
json prompt;
if (body.count("input") != 0) {
@@ -4597,7 +4628,7 @@ int main(int argc, char ** argv) {
}
}
- auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
+ auto tokenized_prompts = tokenize_input_prompts(ctx_server[prio].vocab, prompt, true, true);
for (const auto & tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input
if (tokens.empty()) {
@@ -4615,9 +4646,9 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
- task.id = ctx_server.queue_tasks.get_new_id();
+ task.id = ctx_server[prio].queue_tasks.get_new_id();
task.index = i;
- task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server.mctx != nullptr);
+ task.prompt_tokens = server_tokens(tokenized_prompts[i], ctx_server[prio].mctx != nullptr);
// OAI-compat
task.params.oaicompat = oaicompat;
@@ -4626,12 +4657,12 @@ int main(int argc, char ** argv) {
}
task_ids = server_task::get_list_id(tasks);
- ctx_server.queue_results.add_waiting_tasks(tasks);
- ctx_server.queue_tasks.post(std::move(tasks));
+ ctx_server[prio].queue_results.add_waiting_tasks(tasks);
+ ctx_server[prio].queue_tasks.post(std::move(tasks));
}
// get the result
- ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+ ctx_server[prio].receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
for (auto & res : results) {
GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
responses.push_back(res->to_json());
@@ -4641,7 +4672,7 @@ int main(int argc, char ** argv) {
error = true;
}, req.is_connection_closed);
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+ ctx_server[prio].queue_results.remove_waiting_task_ids(task_ids);
if (error) {
return;
@@ -4663,13 +4694,13 @@ int main(int argc, char ** argv) {
};
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
- if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
+ if (!ctx_server[SERVER_TASK_PRIO_NORMAL].params_base.embedding || ctx_server[SERVER_TASK_PRIO_NORMAL].params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
}
const json body = json::parse(req.body);
-
+ server_task_prio prio = json_value(body,"priority",SERVER_TASK_PRIO_NORMAL);
// TODO: implement
//int top_n = 1;
//if (body.count("top_n") != 1) {
@@ -4703,7 +4734,7 @@ int main(int argc, char ** argv) {
return;
}
- llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
+ llama_tokens tokenized_query = tokenize_input_prompts(ctx_server[prio].vocab, query, /* add_special */ false, true)[0];
// create and queue the task
json responses = json::array();
@@ -4711,23 +4742,23 @@ int main(int argc, char ** argv) {
std::unordered_set<int> task_ids;
{
std::vector<server_task> tasks;
- auto tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
+ auto tokenized_docs = tokenize_input_prompts(ctx_server[prio].vocab, documents, /* add_special */ false, true);
tasks.reserve(tokenized_docs.size());
for (size_t i = 0; i < tokenized_docs.size(); i++) {
- auto tmp = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
+ auto tmp = format_rerank(ctx_server[prio].vocab, tokenized_query, tokenized_docs[i]);
server_task task = server_task(SERVER_TASK_TYPE_RERANK);
- task.id = ctx_server.queue_tasks.get_new_id();
+ task.id = ctx_server[prio].queue_tasks.get_new_id();
task.index = i;
- task.prompt_tokens = server_tokens(tmp, ctx_server.mctx != nullptr);
+ task.prompt_tokens = server_tokens(tmp, ctx_server[prio].mctx != nullptr);
tasks.push_back(std::move(task));
}
task_ids = server_task::get_list_id(tasks);
- ctx_server.queue_results.add_waiting_tasks(tasks);
- ctx_server.queue_tasks.post(std::move(tasks));
+ ctx_server[prio].queue_results.add_waiting_tasks(tasks);
+ ctx_server[prio].queue_tasks.post(std::move(tasks));
}
- ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+ ctx_server[prio].receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
for (auto & res : results) {
GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
responses.push_back(res->to_json());
@@ -4753,7 +4784,7 @@ int main(int argc, char ** argv) {
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
json result = json::array();
- const auto & loras = ctx_server.params_base.lora_adapters;
+ const auto & loras = ctx_server[SERVER_TASK_PRIO_NORMAL].params_base.lora_adapters;
for (size_t i = 0; i < loras.size(); ++i) {
auto & lora = loras[i];
result.push_back({
@@ -4768,23 +4799,24 @@ int main(int argc, char ** argv) {
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);
+ server_task_prio prio = json_value(body,"priority",SERVER_TASK_PRIO_NORMAL);
if (!body.is_array()) {
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
return;
}
- int task_id = ctx_server.queue_tasks.get_new_id();
+ int task_id = ctx_server[prio].queue_tasks.get_new_id();
{
server_task task(SERVER_TASK_TYPE_SET_LORA);
task.id = task_id;
- task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
- ctx_server.queue_results.add_waiting_task_id(task_id);
- ctx_server.queue_tasks.post(std::move(task));
+ task.set_lora = parse_lora_request(ctx_server[prio].params_base.lora_adapters, body);
+ ctx_server[prio].queue_results.add_waiting_task_id(task_id);
+ ctx_server[prio].queue_tasks.post(std::move(task));
}
// get the result
- server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
- ctx_server.queue_results.remove_waiting_task_id(task_id);
+ server_task_result_ptr result = ctx_server[prio].queue_results.recv(task_id);
+ ctx_server[prio].queue_results.remove_waiting_task_id(task_id);
if (result->is_error()) {
res_error(res, result->to_json());
@@ -4874,7 +4906,9 @@ int main(int argc, char ** argv) {
auto clean_up = [&svr, &ctx_server]() {
SRV_INF("%s: cleaning up before exit...\n", __func__);
svr->stop();
- ctx_server.queue_results.terminate();
+ for(int i = 0; i < SERVER_TASK_PRIO_COUNT; i++) {
+ ctx_server[i].queue_results.terminate();
+ }
llama_backend_free();
};
@@ -4915,34 +4949,44 @@ int main(int argc, char ** argv) {
// load the model
LOG_INF("%s: loading model\n", __func__);
- if (!ctx_server.load_model(params)) {
+ if (!ctx_server[SERVER_TASK_PRIO_NORMAL].load_model(params)) {
clean_up();
t.join();
LOG_ERR("%s: exiting due to model loading error\n", __func__);
return 1;
}
+ for(int i = SERVER_TASK_PRIO_NORMAL + 1; i < SERVER_TASK_PRIO_COUNT; i++) {
+ ctx_server[i].load_model(params,ctx_server[SERVER_TASK_PRIO_NORMAL].model);
+ }
- ctx_server.init();
+ for(int i = SERVER_TASK_PRIO_NORMAL; i < SERVER_TASK_PRIO_COUNT; i++) {
+ ctx_server[i].init();
+ ctx_server[i].set_priority((server_task_prio)i);
+ }
state.store(SERVER_STATE_READY);
LOG_INF("%s: model loaded\n", __func__);
// print sample chat example to make it clear which template is used
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
- common_chat_templates_source(ctx_server.chat_templates.get()),
- common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja).c_str());
-
- ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
- ctx_server.process_single_task(std::move(task));
- });
+ common_chat_templates_source(ctx_server[0].chat_templates.get()),
+ common_chat_format_example(ctx_server[0].chat_templates.get(), ctx_server[0].params_base.use_jinja).c_str());
- ctx_server.queue_tasks.on_update_slots([&ctx_server]() {
- ctx_server.update_slots();
- });
+ for(int i = 0; i < SERVER_TASK_PRIO_COUNT; i++) {
+ ctx_server[i].queue_tasks.on_new_task([&ctx_server, i](server_task && task) {
+ ctx_server[i].process_single_task(std::move(task));
+ });
+
+ ctx_server[i].queue_tasks.on_update_slots([&ctx_server, i]() {
+ ctx_server[i].update_slots();
+ });
+ }
shutdown_handler = [&](int) {
// this will unblock start_loop()
- ctx_server.queue_tasks.terminate();
+ for(int i = 0; i < SERVER_TASK_PRIO_COUNT; i++) {
+ ctx_server[i].queue_tasks.terminate();
+ }
};
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
@@ -4963,8 +5007,17 @@ int main(int argc, char ** argv) {
is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() :
string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str());
+ std::vector<std::thread> threads;
// this call blocks the main thread until queue_tasks.terminate() is called
- ctx_server.queue_tasks.start_loop();
+ for(int i = 0; i < SERVER_TASK_PRIO_COUNT; i++) {
+ threads.emplace_back([&ctx_server](int ind) {
+ ctx_server[ind].queue_tasks.start_loop();
+ },i);
+ }
+
+ for(auto &thread: threads) {
+ thread.join();
+ }
clean_up();
t.join();