/*
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * openFuyao is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

package server

import (
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/go-logr/logr/testr"

	"gitcode.com/openFuyao/cache-indexer/pkg/apis"
	"gitcode.com/openFuyao/cache-indexer/pkg/config"
)

// Shared fixtures for the server tests.
const (
	// testReadyPath is the readiness endpoint exercised by TestReadyz_OK.
	testReadyPath = "/readyz"

	// testLargeBodyLimit is a deliberately tiny MaxHitRateBodyBytes value,
	// typed int64, used to force the request-too-large path.
	testLargeBodyLimit int64 = 8
)

// stubScore is a test double for the scorer passed to newMux. It replays a
// canned result (out, err) and remembers the arguments of the most recent
// Compute call in got, so tests can assert what the handler forwarded.
type stubScore struct {
	// out and err are returned verbatim from every Compute call.
	out []apis.ServerScore
	err error

	// got captures the arguments of the last Compute call; it remains at its
	// zero value when Compute has never been invoked.
	got struct {
		mc       apis.ModelContext
		ips      []string
		tokenIDs []int32
	}
}

// Compute records its arguments in s.got and hands back the preconfigured
// s.out and s.err. The context argument is ignored.
func (s *stubScore) Compute(_ context.Context, mc apis.ModelContext, ips []string, tokenIDs []int32) ([]apis.ServerScore, error) {
	s.got.mc, s.got.ips, s.got.tokenIDs = mc, ips, tokenIDs
	return s.out, s.err
}

// newTestMux builds the server's router around st, using a logger bound to t
// and the HTTP config obtained by normalizing an empty config.HTTPConfig.
func newTestMux(t *testing.T, st *stubScore) http.Handler {
	t.Helper()
	logger := testr.New(t)
	return newMux(logger, st, normalizeHTTPConfig(config.HTTPConfig{}))
}

// TestHealthz_OK verifies that the liveness endpoint answers 200 OK.
func TestHealthz_OK(t *testing.T) {
	handler := newTestMux(t, &stubScore{})
	req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("want 200 got %d", got)
	}
}

// TestReadyz_OK verifies that the readiness endpoint answers 200 OK.
func TestReadyz_OK(t *testing.T) {
	handler := newTestMux(t, &stubScore{})
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, testReadyPath, nil)

	handler.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("want 200 got %d", got)
	}
}

// TestHitRate_OK drives a well-formed request through the hit-rate endpoint
// and checks both directions: the JSON response mirrors the scores the stub
// returned, and every decoded request field (server IPs, model context,
// token IDs) reaches the scorer unchanged.
func TestHitRate_OK(t *testing.T) {
	st := &stubScore{out: []apis.ServerScore{
		{ServerIP: "1.1.1.1", L1HitRatio: 0.75, L3HitRatio: 0},
		{ServerIP: "2.2.2.2", L1HitRatio: 0.25, L3HitRatio: 0},
	}}
	mux := newTestMux(t, st)
	body := `{"server_ip":["1.1.1.1","2.2.2.2"],"body":{"token_ids":[1,2,3,4],"cache_salt":"s","block_size":4}}`
	req := httptest.NewRequest(http.MethodPost, "/kv-cache/hit-rate", strings.NewReader(body))
	rr := httptest.NewRecorder()
	mux.ServeHTTP(rr, req)
	if rr.Code != http.StatusOK {
		t.Fatalf("want 200 got %d body=%s", rr.Code, rr.Body.String())
	}
	if got, want := rr.Header().Get("Content-Type"), "application/json"; got != want {
		t.Fatalf("content-type: want %q got %q", want, got)
	}

	// Response side: the handler must relay the stub's scores.
	var resp hitRateResponse
	if err := json.Unmarshal(rr.Body.Bytes(), &resp); err != nil {
		t.Fatalf("unmarshal: %v", err)
	}
	if resp.Status != 0 || len(resp.ServerScoreList) != 2 {
		t.Fatalf("resp=%+v", resp)
	}
	if first := resp.ServerScoreList[0]; first.ServerIP != "1.1.1.1" || first.L1HitRatio != 0.75 {
		t.Fatalf("first=%+v", first)
	}
	if second := resp.ServerScoreList[1]; second.ServerIP != "2.2.2.2" || second.L1HitRatio != 0.25 {
		t.Fatalf("second=%+v", second)
	}

	// Request side: everything decoded from the body must reach the scorer.
	if ips := st.got.ips; len(ips) != 2 || ips[0] != "1.1.1.1" || ips[1] != "2.2.2.2" {
		t.Fatalf("server IPs not propagated: %+v", ips)
	}
	if st.got.mc.BlockSize != 4 || st.got.mc.CacheSalt != "s" {
		t.Fatalf("modelCtx not propagated: %+v", st.got.mc)
	}
	wantTokens := []int32{1, 2, 3, 4}
	if len(st.got.tokenIDs) != len(wantTokens) {
		t.Fatalf("tokens not propagated: %+v", st.got.tokenIDs)
	}
	for i, want := range wantTokens {
		if st.got.tokenIDs[i] != want {
			t.Fatalf("tokens not propagated: %+v", st.got.tokenIDs)
		}
	}
}

// TestHitRate_BadJSON confirms that a body which is not valid JSON is answered
// with 400 Bad Request.
func TestHitRate_BadJSON(t *testing.T) {
	const payload = "not json"
	handler := newTestMux(t, &stubScore{})
	req := httptest.NewRequest(http.MethodPost, "/kv-cache/hit-rate", strings.NewReader(payload))
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Fatalf("want 400 got %d", got)
	}
}

// TestHitRate_EmptyBody confirms that a POST with a zero-length body is
// answered with 400 Bad Request.
func TestHitRate_EmptyBody(t *testing.T) {
	handler := newTestMux(t, &stubScore{})
	req := httptest.NewRequest(http.MethodPost, "/kv-cache/hit-rate", strings.NewReader(""))
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Fatalf("want 400 got %d", got)
	}
}

func TestHitRate_EmptyServerIP(t *testing.T) {
	mux := newTestMux(t, &stubScore{})
	rr := httptest.NewRecorder()
	body := `{"server_ip":[],"body":{"token_ids":[1],"block_size":1}}`
	mux.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, "/kv-cache/hit-rate", strings.NewReader(body)))
	if rr.Code != http.StatusBadRequest {
		t.Fatalf("want 400 got %d", rr.Code)
	}
}

func TestHitRate_BadBlockSize(t *testing.T) {
	mux := newTestMux(t, &stubScore{})
	rr := httptest.NewRecorder()
	body := `{"server_ip":["x"],"body":{"token_ids":[1],"block_size":0}}`
	mux.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, "/kv-cache/hit-rate", strings.NewReader(body)))
	if rr.Code != http.StatusBadRequest {
		t.Fatalf("want 400 got %d", rr.Code)
	}
}

func TestHitRate_NegativeBlockSize(t *testing.T) {
	mux := newTestMux(t, &stubScore{})
	rr := httptest.NewRecorder()
	body := `{"server_ip":["x"],"body":{"token_ids":[1],"block_size":-1}}`
	mux.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, "/kv-cache/hit-rate", strings.NewReader(body)))
	if rr.Code != http.StatusBadRequest {
		t.Fatalf("want 400 got %d", rr.Code)
	}
}

// TestHitRate_MethodNotAllowed confirms that the hit-rate endpoint refuses a
// GET with 405 Method Not Allowed.
func TestHitRate_MethodNotAllowed(t *testing.T) {
	handler := newTestMux(t, &stubScore{})
	req := httptest.NewRequest(http.MethodGet, "/kv-cache/hit-rate", nil)
	rec := httptest.NewRecorder()

	handler.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusMethodNotAllowed {
		t.Fatalf("want 405 got %d", got)
	}
}

// TestHitRate_ScoreError_500 checks that an error returned by the scorer is
// surfaced to the client as 500 Internal Server Error.
func TestHitRate_ScoreError_500(t *testing.T) {
	st := &stubScore{err: errors.New("boom")}
	mux := newTestMux(t, st)
	rr := httptest.NewRecorder()
	body := `{"server_ip":["x"],"body":{"token_ids":[1,2,3,4],"block_size":4}}`
	mux.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, "/kv-cache/hit-rate", strings.NewReader(body)))
	if rr.Code != http.StatusInternalServerError {
		t.Fatalf("want 500 got %d body=%s", rr.Code, rr.Body.String())
	}
	// Guard against a 500 produced before scoring: the stub must actually
	// have been called with the decoded tokens for this test to prove anything.
	if len(st.got.tokenIDs) != 4 {
		t.Fatalf("scorer not invoked as expected: %+v", st.got)
	}
}

func TestHitRate_UnknownFieldsRejected(t *testing.T) {
	mux := newTestMux(t, &stubScore{})
	rr := httptest.NewRecorder()
	body := `{"server_ip":["x"],"surprise":1,"body":{"token_ids":[1],"block_size":1}}`
	mux.ServeHTTP(rr, httptest.NewRequest(http.MethodPost, "/kv-cache/hit-rate", bytes.NewBufferString(body)))
	if rr.Code != http.StatusBadRequest {
		t.Fatalf("want 400 got %d", rr.Code)
	}
}

func TestHitRate_BodyTooLarge(t *testing.T) {
	cfg := normalizeHTTPConfig(config.HTTPConfig{MaxHitRateBodyBytes: testLargeBodyLimit})
	mux := newMux(testr.New(t), &stubScore{}, cfg)
	rr := httptest.NewRecorder()
	body := `{"server_ip":["x"],"body":{"token_ids":[1],"block_size":1}}`
	req := httptest.NewRequest(http.MethodPost, "/kv-cache/hit-rate", strings.NewReader(body))
	mux.ServeHTTP(rr, req)
	if rr.Code != http.StatusRequestEntityTooLarge {
		t.Fatalf("want 413 got %d", rr.Code)
	}
}