/*
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * openFuyao is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 * See the Mulan PSL v2 for more details.
 */

package rdma

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/openfuyao/weight-dispatcher/pkg/internal/errutil"
	sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
)

// ChunkClient abstracts positional reads of source files so the direct and
// striped pull modes are independent of the transport that serves the bytes.
// This file provides a local-filesystem and a node-agent HTTP implementation;
// an RDMA-backed client can satisfy the same contract later.
type ChunkClient interface {
	// Stat reports the size in bytes of relativePath under rootPath on the
	// source identified by endpoint (implementations may ignore endpoint).
	Stat(ctx context.Context, endpoint, rootPath, relativePath string) (int64, error)

	// ReadAt returns the byte range [offset, offset+length) of relativePath
	// under rootPath on the source identified by endpoint.
	ReadAt(
		ctx context.Context,
		endpoint, rootPath, relativePath string,
		offset, length int64,
	) ([]byte, error)
}

// LocalChunkClient is a ChunkClient backed by the local filesystem, used by
// tests and as a local fallback. The context and endpoint arguments of its
// methods are ignored; files are located with
// resolveSourceFilePath(rootPath, relativePath).
type LocalChunkClient struct{}

// Stat returns the size in bytes of one local source file.
func (LocalChunkClient) Stat(_ context.Context, _, rootPath, relativePath string) (int64, error) {
	path := resolveSourceFilePath(rootPath, relativePath)
	info, err := os.Stat(path)
	if err != nil {
		return 0, errutil.Wrap(fmt.Sprintf("stat local chunk file %s", path), err)
	}
	return info.Size(), nil
}

// ReadAt reads the byte range [offset, offset+length) from one local source file.
//
// On success exactly length bytes are returned; a range that extends past the
// end of the file yields an error. A negative length is rejected before any
// allocation (make would otherwise panic at runtime). An error from closing the
// file is merged into the returned error via mergeCloseError.
func (LocalChunkClient) ReadAt(_ context.Context, _, rootPath, relativePath string, offset, length int64) (_ []byte, err error) {
	path := resolveSourceFilePath(rootPath, relativePath)
	if length < 0 {
		return nil, fmt.Errorf("read local chunk file %s at offset %d: negative length %d", path, offset, length)
	}
	file, err := os.Open(path)
	if err != nil {
		return nil, errutil.Wrap(fmt.Sprintf("open local chunk file %s", path), err)
	}
	defer func() {
		err = mergeCloseError(err, file.Close(), fmt.Sprintf("close local chunk file %s", path))
	}()

	buffer := make([]byte, length)
	// os.File.ReadAt returns a non-nil error whenever it reads fewer than
	// len(buffer) bytes; io.EOF together with a full buffer is still a success.
	n, err := file.ReadAt(buffer, offset)
	if err != nil && int64(n) != length {
		return nil, errutil.Wrap(fmt.Sprintf("read local chunk file %s at offset %d", path, offset), err)
	}
	return buffer[:n], nil
}

// HTTPChunkClient calls node-agent export APIs over HTTP.
type HTTPChunkClient struct {
	client *http.Client
}

// SourceChunkRead carries one chunk payload and the CRC32C digest reported by
// the source for it; CRC32C is empty when the source reported none.
type SourceChunkRead struct {
	Payload []byte
	CRC32C  string
}

// NewHTTPChunkClient creates the HTTP export reader used by the RDMA adapter.
// A nil client selects the pooled data-plane client from newDataPlaneHTTPClient.
func NewHTTPChunkClient(client *http.Client) *HTTPChunkClient {
	if client == nil {
		client = newDataPlaneHTTPClient()
	}
	return &HTTPChunkClient{client: client}
}

// newDataPlaneHTTPClient builds the default client for chunk transfers. It has
// a 30s overall request timeout and keeps up to 256 idle connections in total
// (64 per host) for 90s.
//
// The transport is cloned from http.DefaultTransport when that is a
// *http.Transport. Other packages (HTTP mocking and instrumentation libraries,
// for example) may replace http.DefaultTransport with a different
// http.RoundTripper. In that case a fresh *http.Transport, roughly mirroring the
// standard defaults, is used instead of panicking on the type assertion.
func newDataPlaneHTTPClient() *http.Client {
	var transport *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		transport = base.Clone()
	} else {
		transport = &http.Transport{
			Proxy:                 http.ProxyFromEnvironment,
			ForceAttemptHTTP2:     true,
			TLSHandshakeTimeout:   10 * time.Second,
			ExpectContinueTimeout: 1 * time.Second,
		}
	}
	transport.MaxIdleConns = 256
	transport.MaxIdleConnsPerHost = 64
	transport.IdleConnTimeout = 90 * time.Second
	return &http.Client{
		Timeout:   30 * time.Second,
		Transport: transport,
	}
}

// Stat returns the source file size via the node-agent export stat API.
//
// It POSTs a JSON StatExportRequest to <endpoint>/v1/exports/stat and returns
// the SizeBytes field of the decoded StatExportResponse. A non-2xx status is
// reported together with up to 4 KiB of the response body, in the same format
// the export read paths use.
func (c *HTTPChunkClient) Stat(ctx context.Context, endpoint, rootPath, relativePath string) (_ int64, err error) {
	baseURL, err := normalizeAgentEndpoint(endpoint)
	if err != nil {
		return 0, errutil.Wrap("normalize export stat endpoint", err)
	}
	body, err := json.Marshal(sharedtypes.StatExportRequest{RootPath: rootPath, RelativePath: relativePath})
	if err != nil {
		return 0, errutil.Wrap("marshal export stat request", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/v1/exports/stat", bytes.NewReader(body))
	if err != nil {
		return 0, errutil.Wrap("build export stat request", err)
	}
	req.Header.Set("Content-Type", "application/json")
	resp, err := c.client.Do(req)
	if err != nil {
		return 0, errutil.Wrap("post export stat request", err)
	}
	defer func() {
		// Best-effort, bounded drain of bytes the decoder did not consume so the
		// transport can return the keep-alive connection to its idle pool. A
		// drain failure does not affect the result.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
		err = mergeCloseError(err, resp.Body.Close(), "close stat export response body")
	}()
	if resp.StatusCode >= 300 {
		errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 4096))
		// The status code is the primary failure: if the body cannot be read,
		// report the status alone rather than masking it with the read error.
		if readErr != nil || len(errBody) == 0 {
			return 0, fmt.Errorf("stat export failed with status %d", resp.StatusCode)
		}
		return 0, fmt.Errorf("stat export failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(errBody)))
	}
	var output sharedtypes.StatExportResponse
	if decodeErr := json.NewDecoder(resp.Body).Decode(&output); decodeErr != nil {
		return 0, errutil.Wrap("decode export stat response", decodeErr)
	}
	return output.SizeBytes, nil
}

// ReadAt satisfies ChunkClient for the HTTP export API. It performs
// ReadAtWithCRC and hands back only the payload, dropping the digest. Failures
// are wrapped with the requested relative path and offset.
func (c *HTTPChunkClient) ReadAt(ctx context.Context, endpoint, rootPath, relativePath string, offset, length int64) ([]byte, error) {
	chunk, readErr := c.ReadAtWithCRC(ctx, endpoint, rootPath, relativePath, offset, length)
	if readErr == nil {
		return chunk.Payload, nil
	}
	where := fmt.Sprintf("read export chunk %s at offset %d", relativePath, offset)
	return nil, errutil.Wrap(where, readErr)
}

// ReadAtWithCRC fetches length bytes at offset through the node-agent raw read
// endpoint (<endpoint>/v1/exports/readat/raw) and returns them together with
// the CRC32C digest reported by the source, if any.
//
// Responses with a status below 300 are parsed by readRawChunkPayload. Any
// other status is handed to handleRawReadFailure, which may retry through the
// legacy JSON endpoint. The raw response body is always closed, and a close
// error is merged into the returned error.
func (c *HTTPChunkClient) ReadAtWithCRC(ctx context.Context, endpoint, rootPath, relativePath string, offset, length int64) (_ SourceChunkRead, err error) {
	baseURL, err := normalizeAgentEndpoint(endpoint)
	if err != nil {
		return SourceChunkRead{}, errutil.Wrap("normalize export read endpoint", err)
	}
	readReq := sharedtypes.ReadChunkRequest{
		RootPath:     rootPath,
		RelativePath: relativePath,
		Offset:       offset,
		Length:       length,
	}
	resp, err := c.postExportReadRequest(ctx, baseURL+"/v1/exports/readat/raw", readReq, "export raw read")
	if err != nil {
		return SourceChunkRead{}, err
	}
	defer func() {
		err = mergeCloseError(err, resp.Body.Close(), "close readat/raw response body")
	}()

	if resp.StatusCode >= 300 {
		return c.handleRawReadFailure(ctx, resp, baseURL, readReq)
	}
	return readRawChunkPayload(resp, length)
}

// postExportReadRequest JSON-encodes requestPayload and POSTs it to endpoint,
// which must be the full URL of the export API route. action names the
// operation in error messages ("marshal <action> request", "build ...",
// "post ..."). On success the caller owns the response and must close resp.Body.
func (c *HTTPChunkClient) postExportReadRequest(
	ctx context.Context,
	endpoint string,
	requestPayload sharedtypes.ReadChunkRequest,
	action string,
) (*http.Response, error) {
	fail := func(step string, cause error) (*http.Response, error) {
		return nil, errutil.Wrap(step+" "+action+" request", cause)
	}

	encoded, err := json.Marshal(requestPayload)
	if err != nil {
		return fail("marshal", err)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if err != nil {
		return fail("build", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := c.client.Do(httpReq)
	if err != nil {
		return fail("post", err)
	}
	return resp, nil
}

// readRawChunkPayload reads the raw chunk payload from a successful readat/raw
// response and pairs it with the whitespace-trimmed X-Chunk-CRC32C header
// (empty when the header is absent).
//
// The body must contain exactly length bytes. The payload buffer is allocated
// once at that size instead of being grown incrementally, and nothing beyond
// length bytes is buffered. Excess bytes are only counted by streaming them to
// io.Discard, so a misbehaving server cannot force unbounded memory use. A body
// that is shorter or longer than length is reported as a short-read with the
// observed byte count. A negative length is rejected before allocation.
func readRawChunkPayload(resp *http.Response, length int64) (SourceChunkRead, error) {
	if length < 0 {
		return SourceChunkRead{}, fmt.Errorf("read export raw payload: invalid length %d", length)
	}
	payload := make([]byte, length)
	n, err := io.ReadFull(resp.Body, payload)
	// io.ReadFull reports a body that ended before the buffer filled as io.EOF
	// (nothing read) or io.ErrUnexpectedEOF (partially read).
	if err == io.EOF || err == io.ErrUnexpectedEOF {
		return SourceChunkRead{}, fmt.Errorf("read export short-read: expected %d bytes, got %d", length, n)
	}
	if err != nil {
		return SourceChunkRead{}, errutil.Wrap("read export raw response body", err)
	}
	extra, err := io.Copy(io.Discard, resp.Body)
	if err != nil {
		return SourceChunkRead{}, errutil.Wrap("read export raw response body", err)
	}
	if extra > 0 {
		return SourceChunkRead{}, fmt.Errorf("read export short-read: expected %d bytes, got %d", length, length+extra)
	}
	return SourceChunkRead{Payload: payload, CRC32C: strings.TrimSpace(resp.Header.Get("X-Chunk-CRC32C"))}, nil
}

// handleRawReadFailure turns a non-2xx readat/raw response into a result.
//
// Up to 4 KiB of the error body is read for the error message. For any status
// other than 404, the same range is retried once through the legacy JSON
// endpoint (readAtLegacyJSON), and its result is returned if that succeeds.
// Otherwise the raw endpoint's status, plus its body when non-empty, is
// reported. When the fallback was attempted and failed, its error is appended
// and wrapped (%w) rather than discarded, so callers can still inspect it with
// errors.Is / errors.As.
//
// NOTE(review): 404 is excluded from the fallback. If a 404 can also mean "this
// node-agent has no readat/raw route" (an older agent), the legacy path is never
// tried in exactly that case. Confirm the intended semantics against the
// node-agent API.
func (c *HTTPChunkClient) handleRawReadFailure(
	ctx context.Context,
	resp *http.Response,
	baseURL string,
	requestPayload sharedtypes.ReadChunkRequest,
) (SourceChunkRead, error) {
	body, err := io.ReadAll(io.LimitReader(resp.Body, 4096))
	if err != nil {
		return SourceChunkRead{}, errutil.Wrap("read export raw error response body", err)
	}
	rawFailure := fmt.Sprintf("read export failed with status %d", resp.StatusCode)
	if len(body) != 0 {
		rawFailure += ": " + strings.TrimSpace(string(body))
	}
	if resp.StatusCode == http.StatusNotFound {
		return SourceChunkRead{}, fmt.Errorf("%s", rawFailure)
	}
	legacyPayload, legacyErr := c.readAtLegacyJSON(ctx, baseURL, requestPayload)
	if legacyErr == nil {
		return legacyPayload, nil
	}
	return SourceChunkRead{}, fmt.Errorf("%s (legacy fallback: %w)", rawFailure, legacyErr)
}

// readAtLegacyJSON reads one byte range through the legacy JSON export endpoint
// (<baseURL>/v1/exports/readat). It is used as the fallback when the raw read
// endpoint fails.
//
// The request is sent with postExportReadRequest (action "legacy export read").
// The response is a JSON ReadChunkResponse whose Data must be exactly
// requestPayload.Length bytes, the same exact-length rule the raw read path
// enforces. A non-2xx status is reported with up to 4 KiB of the response body.
// The body is drained (bounded) and closed, and a close error is merged into the
// returned error.
func (c *HTTPChunkClient) readAtLegacyJSON(ctx context.Context, baseURL string, requestPayload sharedtypes.ReadChunkRequest) (_ SourceChunkRead, err error) {
	resp, err := c.postExportReadRequest(ctx, baseURL+"/v1/exports/readat", requestPayload, "legacy export read")
	if err != nil {
		return SourceChunkRead{}, err
	}
	defer func() {
		// Best-effort, bounded drain of bytes the decoder did not consume so the
		// keep-alive connection can be reused. A drain failure does not affect
		// the result.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 4096))
		err = mergeCloseError(err, resp.Body.Close(), "close legacy readat response body")
	}()
	if resp.StatusCode >= 300 {
		body, readErr := io.ReadAll(io.LimitReader(resp.Body, 4096))
		if readErr != nil {
			return SourceChunkRead{}, errutil.Wrap("read legacy export error response body", readErr)
		}
		if len(body) == 0 {
			return SourceChunkRead{}, fmt.Errorf("read export failed with status %d", resp.StatusCode)
		}
		return SourceChunkRead{}, fmt.Errorf("read export failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}
	var output sharedtypes.ReadChunkResponse
	if decodeErr := json.NewDecoder(resp.Body).Decode(&output); decodeErr != nil {
		return SourceChunkRead{}, errutil.Wrap("decode legacy export read response", decodeErr)
	}
	if int64(len(output.Data)) != requestPayload.Length {
		return SourceChunkRead{}, fmt.Errorf("read export short-read: expected %d bytes, got %d", requestPayload.Length, len(output.Data))
	}
	return SourceChunkRead{Payload: output.Data, CRC32C: output.CRC32C}, nil
}