* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* openFuyao is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
package rdma
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strings"
"time"
"github.com/openfuyao/weight-dispatcher/pkg/internal/errutil"
sharedtypes "github.com/openfuyao/weight-dispatcher/pkg/types"
)
// ChunkClient abstracts access to chunk source files. A file is addressed by
// an endpoint together with a root path and a path relative to that root.
type ChunkClient interface {
// Stat returns the size in bytes of the addressed file.
Stat(ctx context.Context, endpoint, rootPath, relativePath string) (int64, error)
// ReadAt returns bytes of the addressed file starting at offset; length is
// the number of bytes requested.
ReadAt(ctx context.Context, endpoint, rootPath, relativePath string, offset, length int64) ([]byte, error)
}
// LocalChunkClient serves chunk requests from the local filesystem. It carries
// no state.
type LocalChunkClient struct{}
// Stat reports the size in bytes of the local file resolved from rootPath and
// relativePath via resolveSourceFilePath. The context and endpoint arguments
// are ignored. A failing os.Stat is returned wrapped with the resolved path.
func (LocalChunkClient) Stat(_ context.Context, _, rootPath, relativePath string) (int64, error) {
	localPath := resolveSourceFilePath(rootPath, relativePath)

	fileInfo, statErr := os.Stat(localPath)
	if statErr == nil {
		return fileInfo.Size(), nil
	}
	return 0, errutil.Wrap(fmt.Sprintf("stat local chunk file %s", localPath), statErr)
}
// ReadAt reads length bytes starting at offset from the local file resolved
// from rootPath and relativePath via resolveSourceFilePath. The context and
// endpoint arguments are ignored.
//
// It returns an error when length is negative, when the file cannot be opened,
// or when the read fails before length bytes were obtained. The result of
// closing the file is passed to mergeCloseError together with the function's
// own error, and the merged value becomes the returned error.
func (LocalChunkClient) ReadAt(_ context.Context, _, rootPath, relativePath string, offset, length int64) (_ []byte, err error) {
	path := resolveSourceFilePath(rootPath, relativePath)

	// make panics on a negative size, so reject it up front as an ordinary
	// error instead of crashing the caller.
	if length < 0 {
		return nil, fmt.Errorf("read local chunk file %s: negative length %d", path, length)
	}

	file, err := os.Open(path)
	if err != nil {
		return nil, errutil.Wrap(fmt.Sprintf("open local chunk file %s", path), err)
	}
	defer func() {
		err = mergeCloseError(err, file.Close(), fmt.Sprintf("close local chunk file %s", path))
	}()

	buffer := make([]byte, length)
	n, err := file.ReadAt(buffer, offset)
	// An error that accompanies a completely filled buffer (for example io.EOF
	// exactly at the end of the file) is not a failure; only a short read is.
	if err != nil && int64(n) != length {
		return nil, errutil.Wrap(fmt.Sprintf("read local chunk file %s at offset %d", path, offset), err)
	}
	return buffer[:n], nil
}
// HTTPChunkClient performs chunk stat and read requests against the
// /v1/exports/... HTTP endpoints of a remote endpoint.
type HTTPChunkClient struct {
// client is the HTTP client used to send every request.
client *http.Client
}
// SourceChunkRead is the result of a chunk read: the chunk bytes together with
// the checksum string reported by the source.
type SourceChunkRead struct {
// Payload holds the chunk bytes.
Payload []byte
// CRC32C is the checksum value as reported by the source (the
// X-Chunk-CRC32C response header, or the CRC32C field of the legacy JSON
// response). It is copied from the response and may be empty.
CRC32C string
}
// NewHTTPChunkClient builds an HTTPChunkClient around the given HTTP client.
// When client is nil, the client returned by newDataPlaneHTTPClient is used
// instead.
func NewHTTPChunkClient(client *http.Client) *HTTPChunkClient {
	chunkClient := &HTTPChunkClient{client: client}
	if chunkClient.client == nil {
		chunkClient.client = newDataPlaneHTTPClient()
	}
	return chunkClient
}
// newDataPlaneHTTPClient returns an HTTP client with a 30 second overall
// request timeout and a private transport tuned for many idle connections
// (256 in total, 64 per host, 90 second idle timeout).
//
// The transport is a clone of http.DefaultTransport when that is an
// *http.Transport. If the process has replaced http.DefaultTransport with a
// different RoundTripper, a fresh transport is built instead of panicking on
// the type assertion.
func newDataPlaneHTTPClient() *http.Client {
	// Fallback used only when http.DefaultTransport is not an *http.Transport.
	transport := &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		ForceAttemptHTTP2:     true,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: time.Second,
	}
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		// Clone so the tuning below never mutates the shared default.
		transport = base.Clone()
	}
	transport.MaxIdleConns = 256
	transport.MaxIdleConnsPerHost = 64
	transport.IdleConnTimeout = 90 * time.Second
	return &http.Client{
		Timeout:   30 * time.Second,
		Transport: transport,
	}
}
// Stat asks the export endpoint for the size of a file by POSTing a JSON
// StatExportRequest to <endpoint>/v1/exports/stat and returns the SizeBytes
// field of the decoded response.
//
// The endpoint is first passed through normalizeAgentEndpoint. Any response
// status of 300 or above is an error; up to 4096 bytes of the response body
// are included in the message when present, matching the read paths. The
// result of closing the response body is passed to mergeCloseError together
// with the function's own error.
func (c *HTTPChunkClient) Stat(ctx context.Context, endpoint, rootPath, relativePath string) (_ int64, err error) {
	baseURL, err := normalizeAgentEndpoint(endpoint)
	if err != nil {
		return 0, errutil.Wrap("normalize export stat endpoint", err)
	}
	body, err := json.Marshal(sharedtypes.StatExportRequest{RootPath: rootPath, RelativePath: relativePath})
	if err != nil {
		return 0, errutil.Wrap("marshal export stat request", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/v1/exports/stat", bytes.NewReader(body))
	if err != nil {
		return 0, errutil.Wrap("build export stat request", err)
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := c.client.Do(req)
	if err != nil {
		return 0, errutil.Wrap("post export stat request", err)
	}
	defer func() {
		err = mergeCloseError(err, resp.Body.Close(), "close stat export response body")
	}()

	if resp.StatusCode >= 300 {
		// Best effort: the status code is the primary signal, so a failure
		// while reading the diagnostic body is deliberately ignored and the
		// status-only message is returned.
		detail, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
		if message := strings.TrimSpace(string(detail)); message != "" {
			return 0, fmt.Errorf("stat export failed with status %d: %s", resp.StatusCode, message)
		}
		return 0, fmt.Errorf("stat export failed with status %d", resp.StatusCode)
	}

	var output sharedtypes.StatExportResponse
	if decodeErr := json.NewDecoder(resp.Body).Decode(&output); decodeErr != nil {
		return 0, errutil.Wrap("decode export stat response", decodeErr)
	}
	return output.SizeBytes, nil
}
// ReadAt fetches a chunk through ReadAtWithCRC and returns only its payload,
// dropping the checksum. Errors are wrapped with the relative path and offset
// of the requested chunk.
func (c *HTTPChunkClient) ReadAt(ctx context.Context, endpoint, rootPath, relativePath string, offset, length int64) ([]byte, error) {
	chunk, readErr := c.ReadAtWithCRC(ctx, endpoint, rootPath, relativePath, offset, length)
	if readErr == nil {
		return chunk.Payload, nil
	}
	return nil, errutil.Wrap(fmt.Sprintf("read export chunk %s at offset %d", relativePath, offset), readErr)
}
// ReadAtWithCRC requests a chunk from <endpoint>/v1/exports/readat/raw and
// returns its payload together with the checksum reported by the server.
//
// The endpoint is first passed through normalizeAgentEndpoint. A response
// status below 300 is handed to readRawChunkPayload; any other status is
// handed to handleRawReadFailure. The result of closing the response body is
// passed to mergeCloseError together with the function's own error.
func (c *HTTPChunkClient) ReadAtWithCRC(ctx context.Context, endpoint, rootPath, relativePath string, offset, length int64) (_ SourceChunkRead, err error) {
	agentURL, err := normalizeAgentEndpoint(endpoint)
	if err != nil {
		return SourceChunkRead{}, errutil.Wrap("normalize export read endpoint", err)
	}

	readRequest := sharedtypes.ReadChunkRequest{
		RootPath:     rootPath,
		RelativePath: relativePath,
		Offset:       offset,
		Length:       length,
	}
	resp, err := c.postExportReadRequest(ctx, agentURL+"/v1/exports/readat/raw", readRequest, "export raw read")
	if err != nil {
		return SourceChunkRead{}, err
	}
	defer func() {
		closeErr := resp.Body.Close()
		err = mergeCloseError(err, closeErr, "close readat/raw response body")
	}()

	if resp.StatusCode >= 300 {
		return c.handleRawReadFailure(ctx, resp, agentURL, readRequest)
	}
	return readRawChunkPayload(resp, length)
}
// postExportReadRequest JSON-encodes requestPayload and POSTs it to endpoint
// with a Content-Type of application/json. The action string is only used to
// label error messages ("marshal <action> request", "build <action> request",
// "post <action> request").
//
// On success the caller receives the response and is responsible for closing
// its body. The response status is not inspected here.
func (c *HTTPChunkClient) postExportReadRequest(
	ctx context.Context,
	endpoint string,
	requestPayload sharedtypes.ReadChunkRequest,
	action string,
) (*http.Response, error) {
	encoded, marshalErr := json.Marshal(requestPayload)
	if marshalErr != nil {
		return nil, errutil.Wrap("marshal "+action+" request", marshalErr)
	}

	httpReq, buildErr := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(encoded))
	if buildErr != nil {
		return nil, errutil.Wrap("build "+action+" request", buildErr)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	httpResp, doErr := c.client.Do(httpReq)
	if doErr != nil {
		return nil, errutil.Wrap("post "+action+" request", doErr)
	}
	return httpResp, nil
}
// readRawChunkPayload reads the body of a successful raw read response and
// returns it with the whitespace-trimmed X-Chunk-CRC32C header value.
//
// The body must contain exactly length bytes; any other size is an error. At
// most length+1 bytes are read from the body, so a response far larger than
// requested is rejected without being buffered in full. For such an oversized
// body the byte count in the error message is therefore capped at length+1.
// The caller remains responsible for closing resp.Body.
func readRawChunkPayload(resp *http.Response, length int64) (SourceChunkRead, error) {
	// One byte beyond the expected size is enough to tell "exact" from
	// "too long" while keeping memory use bounded by the requested length.
	payload, err := io.ReadAll(io.LimitReader(resp.Body, length+1))
	if err != nil {
		return SourceChunkRead{}, errutil.Wrap("read export raw response body", err)
	}
	if got := int64(len(payload)); got != length {
		return SourceChunkRead{}, fmt.Errorf("read export short-read: expected %d bytes, got %d", length, got)
	}
	return SourceChunkRead{
		Payload: payload,
		CRC32C:  strings.TrimSpace(resp.Header.Get("X-Chunk-CRC32C")),
	}, nil
}
// handleRawReadFailure deals with a raw read response whose status is 300 or
// above. It reads up to 4096 bytes of the response body for diagnostics and,
// unless the status is 404, retries the read through readAtLegacyJSON. A
// successful legacy read is returned as the result; a failed one is discarded
// and the original raw failure (status plus body text, if any) is reported.
//
// NOTE(review): the legacy fallback runs for every status EXCEPT 404. A server
// that lacks the raw route would typically answer 404, so this condition may
// be inverted - confirm the intended behavior against the agent.
func (c *HTTPChunkClient) handleRawReadFailure(
	ctx context.Context,
	resp *http.Response,
	baseURL string,
	requestPayload sharedtypes.ReadChunkRequest,
) (SourceChunkRead, error) {
	errBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 4096))
	if readErr != nil {
		return SourceChunkRead{}, errutil.Wrap("read export raw error response body", readErr)
	}

	status := resp.StatusCode
	if status != http.StatusNotFound {
		if chunk, legacyErr := c.readAtLegacyJSON(ctx, baseURL, requestPayload); legacyErr == nil {
			return chunk, nil
		}
	}

	if len(errBody) > 0 {
		return SourceChunkRead{}, fmt.Errorf("read export failed with status %d: %s", status, strings.TrimSpace(string(errBody)))
	}
	return SourceChunkRead{}, fmt.Errorf("read export failed with status %d", status)
}
// readAtLegacyJSON reads a chunk through the JSON endpoint
// <baseURL>/v1/exports/readat and returns the Data and CRC32C fields of the
// decoded response.
//
// The request is sent with postExportReadRequest, the same helper the raw read
// path uses. A response status of 300 or above is an error that includes up to
// 4096 bytes of the response body when present. As on the raw path, a payload
// whose size differs from requestPayload.Length is rejected rather than
// returned. The result of closing the response body is passed to
// mergeCloseError together with the function's own error.
func (c *HTTPChunkClient) readAtLegacyJSON(ctx context.Context, baseURL string, requestPayload sharedtypes.ReadChunkRequest) (_ SourceChunkRead, err error) {
	resp, err := c.postExportReadRequest(ctx, baseURL+"/v1/exports/readat", requestPayload, "legacy export read")
	if err != nil {
		return SourceChunkRead{}, err
	}
	defer func() {
		err = mergeCloseError(err, resp.Body.Close(), "close legacy readat response body")
	}()

	if resp.StatusCode >= 300 {
		detail, readErr := io.ReadAll(io.LimitReader(resp.Body, 4096))
		if readErr != nil {
			return SourceChunkRead{}, errutil.Wrap("read legacy export error response body", readErr)
		}
		if len(detail) == 0 {
			return SourceChunkRead{}, fmt.Errorf("read export failed with status %d", resp.StatusCode)
		}
		return SourceChunkRead{}, fmt.Errorf("read export failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(detail)))
	}

	var output sharedtypes.ReadChunkResponse
	if decodeErr := json.NewDecoder(resp.Body).Decode(&output); decodeErr != nil {
		return SourceChunkRead{}, errutil.Wrap("decode legacy export read response", decodeErr)
	}
	// Keep the legacy path as strict as the raw path: never hand back a
	// payload of a different size than the caller asked for.
	if got := int64(len(output.Data)); got != requestPayload.Length {
		return SourceChunkRead{}, fmt.Errorf("read export short-read: expected %d bytes, got %d", requestPayload.Length, got)
	}
	return SourceChunkRead{Payload: output.Data, CRC32C: output.CRC32C}, nil
}