mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-20 07:57:43 +08:00
feat(proxy): add proxy IP management and account routing
Add admin CRUD and connectivity checks for SOCKS5/SOCKS5H proxy nodes. Allow accounts to bind to a proxy, route DeepSeek requests through the selected node, and expose proxy management in the admin UI.
This commit is contained in:
@@ -13,6 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) {
|
||||
clients := c.requestClientsForAccount(acc)
|
||||
payload := map[string]any{
|
||||
"password": strings.TrimSpace(acc.Password),
|
||||
"device_id": "deepseek_to_api",
|
||||
@@ -27,7 +28,7 @@ func (c *Client) Login(ctx context.Context, acc config.Account) (string, error)
|
||||
} else {
|
||||
return "", errors.New("missing email/mobile")
|
||||
}
|
||||
resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload)
|
||||
resp, err := c.postJSON(ctx, clients.regular, DeepSeekLoginURL, BaseHeaders, payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -52,11 +53,12 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
|
||||
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
|
||||
if err != nil {
|
||||
config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID)
|
||||
attempts++
|
||||
@@ -94,11 +96,12 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
attempts := 0
|
||||
refreshed := false
|
||||
for attempts < maxAttempts {
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
|
||||
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
|
||||
if err != nil {
|
||||
config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID)
|
||||
attempts++
|
||||
|
||||
@@ -10,18 +10,20 @@ import (
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
trans "ds2api/internal/deepseek/transport"
|
||||
)
|
||||
|
||||
func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) {
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
headers["x-ds-pow-response"] = powResp
|
||||
captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
|
||||
attempts := 0
|
||||
for attempts < maxAttempts {
|
||||
resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
|
||||
resp, err := c.streamPost(ctx, clients.stream, DeepSeekCompletionURL, headers, payload)
|
||||
if err != nil {
|
||||
attempts++
|
||||
time.Sleep(time.Second)
|
||||
@@ -44,11 +46,12 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa
|
||||
return nil, errors.New("completion failed")
|
||||
}
|
||||
|
||||
func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) {
|
||||
func (c *Client) streamPost(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (*http.Response, error) {
|
||||
b, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clients := c.requestClientsFromContext(ctx)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -56,7 +59,7 @@ func (c *Client) streamPost(ctx context.Context, url string, headers map[string]
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
resp, err := c.stream.Do(req)
|
||||
resp, err := doer.Do(req)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err)
|
||||
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
@@ -66,7 +69,7 @@ func (c *Client) streamPost(ctx context.Context, url string, headers map[string]
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
}
|
||||
return c.fallbackS.Do(req2)
|
||||
return clients.fallbackS.Do(req2)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
@@ -51,6 +51,7 @@ func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionI
|
||||
if strings.TrimSpace(sessionID) == "" || responseMessageID <= 0 {
|
||||
return nil, errors.New("missing continue identifiers")
|
||||
}
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
headers["x-ds-pow-response"] = powResp
|
||||
payload := map[string]any{
|
||||
@@ -60,7 +61,7 @@ func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionI
|
||||
}
|
||||
config.Logger.Info("[auto_continue] calling continue", "session_id", sessionID, "message_id", responseMessageID)
|
||||
captureSession := c.capture.Start("deepseek_continue", DeepSeekContinueURL, a.AccountID, payload)
|
||||
resp, err := c.streamPost(ctx, DeepSeekContinueURL, headers, payload)
|
||||
resp, err := c.streamPost(ctx, clients.stream, DeepSeekContinueURL, headers, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package deepseek
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
@@ -24,18 +25,22 @@ type Client struct {
|
||||
fallback *http.Client
|
||||
fallbackS *http.Client
|
||||
maxRetries int
|
||||
|
||||
proxyClientsMu sync.RWMutex
|
||||
proxyClients map[string]requestClients
|
||||
}
|
||||
|
||||
func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
|
||||
return &Client{
|
||||
Store: store,
|
||||
Auth: resolver,
|
||||
capture: devcapture.Global(),
|
||||
regular: trans.New(60 * time.Second),
|
||||
stream: trans.New(0),
|
||||
fallback: &http.Client{Timeout: 60 * time.Second},
|
||||
fallbackS: &http.Client{Timeout: 0},
|
||||
maxRetries: 3,
|
||||
Store: store,
|
||||
Auth: resolver,
|
||||
capture: devcapture.Global(),
|
||||
regular: trans.New(60 * time.Second),
|
||||
stream: trans.New(0),
|
||||
fallback: &http.Client{Timeout: 60 * time.Second},
|
||||
fallbackS: &http.Client{Timeout: 0},
|
||||
maxRetries: 3,
|
||||
proxyClients: map[string]requestClients{},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
clients := c.requestClientsFromContext(ctx)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -44,7 +45,7 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
}
|
||||
resp, err = c.fallback.Do(req2)
|
||||
resp, err = clients.fallback.Do(req2)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -64,6 +65,7 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
|
||||
}
|
||||
|
||||
func (c *Client) getJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string) (map[string]any, int, error) {
|
||||
clients := c.requestClientsFromContext(ctx)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@@ -81,7 +83,7 @@ func (c *Client) getJSONWithStatus(ctx context.Context, doer trans.Doer, url str
|
||||
for k, v := range headers {
|
||||
req2.Header.Set(k, v)
|
||||
}
|
||||
resp, err = c.fallback.Do(req2)
|
||||
resp, err = clients.fallback.Do(req2)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@@ -36,6 +36,7 @@ func (c *Client) GetSessionCount(ctx context.Context, a *auth.RequestAuth, maxAt
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
|
||||
stats := &SessionStats{
|
||||
AccountID: a.AccountID,
|
||||
@@ -50,7 +51,7 @@ func (c *Client) GetSessionCount(ctx context.Context, a *auth.RequestAuth, maxAt
|
||||
// 构建请求 URL
|
||||
reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
|
||||
|
||||
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
|
||||
resp, status, err := c.getJSONWithStatus(ctx, clients.regular, reqURL, headers)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[get_session_count] request error", "error", err, "account", a.AccountID)
|
||||
attempts++
|
||||
@@ -106,10 +107,11 @@ func (c *Client) GetSessionCount(ctx context.Context, a *auth.RequestAuth, maxAt
|
||||
|
||||
// GetSessionCountForToken 直接使用 token 获取会话数量(直通模式)
|
||||
func (c *Client) GetSessionCountForToken(ctx context.Context, token string) (*SessionStats, error) {
|
||||
clients := c.requestClientsFromContext(ctx)
|
||||
headers := c.authHeaders(token)
|
||||
reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
|
||||
|
||||
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
|
||||
resp, status, err := c.getJSONWithStatus(ctx, clients.regular, reqURL, headers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -160,7 +162,7 @@ func (c *Client) GetSessionCountAll(ctx context.Context) []*SessionStats {
|
||||
// 如果没有 token,尝试登录获取
|
||||
if token == "" {
|
||||
var err error
|
||||
token, err = c.Login(ctx, acc)
|
||||
token, err = c.Login(auth.WithAuth(ctx, &auth.RequestAuth{AccountID: acc.Identifier(), Account: acc}), acc)
|
||||
if err != nil {
|
||||
results = append(results, &SessionStats{
|
||||
AccountID: accountID,
|
||||
@@ -171,7 +173,8 @@ func (c *Client) GetSessionCountAll(ctx context.Context) []*SessionStats {
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := c.GetSessionCountForToken(ctx, token)
|
||||
ctxWithAuth := auth.WithAuth(ctx, &auth.RequestAuth{AccountID: acc.Identifier(), Account: acc, DeepSeekToken: token})
|
||||
stats, err := c.GetSessionCountForToken(ctxWithAuth, token)
|
||||
if err != nil {
|
||||
results = append(results, &SessionStats{
|
||||
AccountID: accountID,
|
||||
@@ -190,6 +193,7 @@ func (c *Client) GetSessionCountAll(ctx context.Context) []*SessionStats {
|
||||
|
||||
// FetchSessionPage 获取会话列表(支持分页)
|
||||
func (c *Client) FetchSessionPage(ctx context.Context, a *auth.RequestAuth, cursor string) ([]SessionInfo, bool, error) {
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
|
||||
// 构建请求 URL
|
||||
@@ -200,7 +204,7 @@ func (c *Client) FetchSessionPage(ctx context.Context, a *auth.RequestAuth, curs
|
||||
}
|
||||
reqURL := DeepSeekFetchSessionURL + "?" + params.Encode()
|
||||
|
||||
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
|
||||
resp, status, err := c.getJSONWithStatus(ctx, clients.regular, reqURL, headers)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
@@ -22,6 +22,7 @@ func (c *Client) DeleteSession(ctx context.Context, a *auth.RequestAuth, session
|
||||
if maxAttempts <= 0 {
|
||||
maxAttempts = c.maxRetries
|
||||
}
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
|
||||
result := &DeleteSessionResult{
|
||||
SessionID: sessionID,
|
||||
@@ -42,7 +43,7 @@ func (c *Client) DeleteSession(ctx context.Context, a *auth.RequestAuth, session
|
||||
"chat_session_id": sessionID,
|
||||
}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekDeleteSessionURL, headers, payload)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[delete_session] request error", "error", err, "session_id", sessionID)
|
||||
attempts++
|
||||
@@ -81,6 +82,7 @@ func (c *Client) DeleteSession(ctx context.Context, a *auth.RequestAuth, session
|
||||
|
||||
// DeleteSessionForToken 直接使用 token 删除会话(直通模式)
|
||||
func (c *Client) DeleteSessionForToken(ctx context.Context, token string, sessionID string) (*DeleteSessionResult, error) {
|
||||
clients := c.requestClientsFromContext(ctx)
|
||||
result := &DeleteSessionResult{
|
||||
SessionID: sessionID,
|
||||
}
|
||||
@@ -95,7 +97,7 @@ func (c *Client) DeleteSessionForToken(ctx context.Context, token string, sessio
|
||||
"chat_session_id": sessionID,
|
||||
}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekDeleteSessionURL, headers, payload)
|
||||
if err != nil {
|
||||
result.ErrorMessage = err.Error()
|
||||
return result, err
|
||||
@@ -114,10 +116,11 @@ func (c *Client) DeleteSessionForToken(ctx context.Context, token string, sessio
|
||||
|
||||
// DeleteAllSessions 删除所有会话(谨慎使用)
|
||||
func (c *Client) DeleteAllSessions(ctx context.Context, a *auth.RequestAuth) error {
|
||||
clients := c.requestClientsForAuth(ctx, a)
|
||||
headers := c.authHeaders(a.DeepSeekToken)
|
||||
payload := map[string]any{}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[delete_all_sessions] request error", "error", err)
|
||||
return err
|
||||
@@ -135,10 +138,11 @@ func (c *Client) DeleteAllSessions(ctx context.Context, a *auth.RequestAuth) err
|
||||
|
||||
// DeleteAllSessionsForToken 直接使用 token 删除所有会话(直通模式)
|
||||
func (c *Client) DeleteAllSessionsForToken(ctx context.Context, token string) error {
|
||||
clients := c.requestClientsFromContext(ctx)
|
||||
headers := c.authHeaders(token)
|
||||
payload := map[string]any{}
|
||||
|
||||
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
|
||||
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[delete_all_sessions_for_token] request error", "error", err)
|
||||
return err
|
||||
|
||||
239
internal/deepseek/proxy.go
Normal file
239
internal/deepseek/proxy.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
"ds2api/internal/auth"
|
||||
"ds2api/internal/config"
|
||||
trans "ds2api/internal/deepseek/transport"
|
||||
)
|
||||
|
||||
type requestClients struct {
|
||||
regular trans.Doer
|
||||
stream trans.Doer
|
||||
fallback *http.Client
|
||||
fallbackS *http.Client
|
||||
}
|
||||
|
||||
type hostLookupFunc func(ctx context.Context, network, host string) ([]string, error)
|
||||
|
||||
var proxyConnectivityTestURL = "https://chat.deepseek.com/"
|
||||
|
||||
var defaultHostLookup hostLookupFunc = func(ctx context.Context, _ string, host string) ([]string, error) {
|
||||
return net.DefaultResolver.LookupHost(ctx, host)
|
||||
}
|
||||
|
||||
func proxyDialAddress(ctx context.Context, proxyType, address string, lookup hostLookupFunc) (string, error) {
|
||||
proxyType = strings.ToLower(strings.TrimSpace(proxyType))
|
||||
if proxyType != "socks5" {
|
||||
return address, nil
|
||||
}
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if net.ParseIP(host) != nil {
|
||||
return address, nil
|
||||
}
|
||||
if lookup == nil {
|
||||
lookup = defaultHostLookup
|
||||
}
|
||||
addrs, err := lookup(ctx, "ip", host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if len(addrs) == 0 {
|
||||
return "", fmt.Errorf("no ip address resolved for %s", host)
|
||||
}
|
||||
return net.JoinHostPort(addrs[0], port), nil
|
||||
}
|
||||
|
||||
func proxyCacheKey(proxyCfg config.Proxy) string {
|
||||
proxyCfg = config.NormalizeProxy(proxyCfg)
|
||||
return strings.Join([]string{
|
||||
proxyCfg.ID,
|
||||
proxyCfg.Type,
|
||||
strings.ToLower(proxyCfg.Host),
|
||||
strconv.Itoa(proxyCfg.Port),
|
||||
proxyCfg.Username,
|
||||
proxyCfg.Password,
|
||||
}, "|")
|
||||
}
|
||||
|
||||
func proxyDialContext(proxyCfg config.Proxy) (trans.DialContextFunc, error) {
|
||||
proxyCfg = config.NormalizeProxy(proxyCfg)
|
||||
var authCfg *proxy.Auth
|
||||
if proxyCfg.Username != "" || proxyCfg.Password != "" {
|
||||
authCfg = &proxy.Auth{User: proxyCfg.Username, Password: proxyCfg.Password}
|
||||
}
|
||||
forward := &net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}
|
||||
dialer, err := proxy.SOCKS5("tcp", net.JoinHostPort(proxyCfg.Host, strconv.Itoa(proxyCfg.Port)), authCfg, forward)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return func(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
target, err := proxyDialAddress(ctx, proxyCfg.Type, address, defaultHostLookup)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ctxDialer, ok := dialer.(proxy.ContextDialer); ok {
|
||||
return ctxDialer.DialContext(ctx, network, target)
|
||||
}
|
||||
return dialer.Dial(network, target)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) defaultRequestClients() requestClients {
|
||||
return requestClients{
|
||||
regular: c.regular,
|
||||
stream: c.stream,
|
||||
fallback: c.fallback,
|
||||
fallbackS: c.fallbackS,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) resolveProxyForAccount(acc config.Account) (config.Proxy, bool) {
|
||||
if c == nil || c.Store == nil {
|
||||
return config.Proxy{}, false
|
||||
}
|
||||
proxyID := strings.TrimSpace(acc.ProxyID)
|
||||
if proxyID == "" {
|
||||
return config.Proxy{}, false
|
||||
}
|
||||
snap := c.Store.Snapshot()
|
||||
for _, proxyCfg := range snap.Proxies {
|
||||
proxyCfg = config.NormalizeProxy(proxyCfg)
|
||||
if proxyCfg.ID == proxyID {
|
||||
return proxyCfg, true
|
||||
}
|
||||
}
|
||||
return config.Proxy{}, false
|
||||
}
|
||||
|
||||
func (c *Client) requestClientsFromContext(ctx context.Context) requestClients {
|
||||
if a, ok := auth.FromContext(ctx); ok {
|
||||
return c.requestClientsForAccount(a.Account)
|
||||
}
|
||||
return c.defaultRequestClients()
|
||||
}
|
||||
|
||||
func (c *Client) requestClientsForAuth(ctx context.Context, a *auth.RequestAuth) requestClients {
|
||||
if a != nil {
|
||||
return c.requestClientsForAccount(a.Account)
|
||||
}
|
||||
return c.requestClientsFromContext(ctx)
|
||||
}
|
||||
|
||||
func (c *Client) requestClientsForAccount(acc config.Account) requestClients {
|
||||
proxyCfg, ok := c.resolveProxyForAccount(acc)
|
||||
if !ok {
|
||||
return c.defaultRequestClients()
|
||||
}
|
||||
|
||||
key := proxyCacheKey(proxyCfg)
|
||||
c.proxyClientsMu.RLock()
|
||||
cached, ok := c.proxyClients[key]
|
||||
c.proxyClientsMu.RUnlock()
|
||||
if ok {
|
||||
return cached
|
||||
}
|
||||
|
||||
dialContext, err := proxyDialContext(proxyCfg)
|
||||
if err != nil {
|
||||
config.Logger.Warn("[proxy] build dialer failed", "proxy_id", proxyCfg.ID, "error", err)
|
||||
return c.defaultRequestClients()
|
||||
}
|
||||
|
||||
bundle := requestClients{
|
||||
regular: trans.NewWithDialContext(60*time.Second, dialContext),
|
||||
stream: trans.NewWithDialContext(0, dialContext),
|
||||
fallback: trans.NewFallbackClient(60*time.Second, dialContext),
|
||||
fallbackS: trans.NewFallbackClient(0, dialContext),
|
||||
}
|
||||
|
||||
c.proxyClientsMu.Lock()
|
||||
if c.proxyClients == nil {
|
||||
c.proxyClients = make(map[string]requestClients)
|
||||
}
|
||||
c.proxyClients[key] = bundle
|
||||
c.proxyClientsMu.Unlock()
|
||||
return bundle
|
||||
}
|
||||
|
||||
func applyProxyConnectivityHeaders(req *http.Request) {
|
||||
if req == nil {
|
||||
return
|
||||
}
|
||||
for key, value := range BaseHeaders {
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key == "" || value == "" {
|
||||
continue
|
||||
}
|
||||
req.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func proxyConnectivityStatus(statusCode int) (bool, string) {
|
||||
switch {
|
||||
case statusCode >= 200 && statusCode < 300:
|
||||
return true, fmt.Sprintf("代理可达,目标返回 HTTP %d", statusCode)
|
||||
case statusCode >= 300 && statusCode < 500:
|
||||
return true, fmt.Sprintf("代理可达,但目标返回 HTTP %d(可能是风控或挑战)", statusCode)
|
||||
default:
|
||||
return false, fmt.Sprintf("目标返回 HTTP %d", statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyConnectivity(ctx context.Context, proxyCfg config.Proxy) map[string]any {
|
||||
start := time.Now()
|
||||
proxyCfg = config.NormalizeProxy(proxyCfg)
|
||||
result := map[string]any{
|
||||
"success": false,
|
||||
"proxy_id": proxyCfg.ID,
|
||||
"proxy_type": proxyCfg.Type,
|
||||
"response_time": 0,
|
||||
}
|
||||
|
||||
if err := config.ValidateProxyConfig([]config.Proxy{proxyCfg}); err != nil {
|
||||
result["message"] = "代理配置无效: " + err.Error()
|
||||
return result
|
||||
}
|
||||
dialContext, err := proxyDialContext(proxyCfg)
|
||||
if err != nil {
|
||||
result["message"] = "代理拨号器初始化失败: " + err.Error()
|
||||
return result
|
||||
}
|
||||
|
||||
client := trans.NewFallbackClient(15*time.Second, dialContext)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxyConnectivityTestURL, nil)
|
||||
if err != nil {
|
||||
result["message"] = err.Error()
|
||||
return result
|
||||
}
|
||||
applyProxyConnectivityHeaders(req)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
result["response_time"] = int(time.Since(start).Milliseconds())
|
||||
if err != nil {
|
||||
result["message"] = err.Error()
|
||||
return result
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := resp.Body.Close(); closeErr != nil {
|
||||
config.Logger.Warn("[proxy] close response body failed", "proxy_id", proxyCfg.ID, "error", closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
result["status_code"] = resp.StatusCode
|
||||
result["success"], result["message"] = proxyConnectivityStatus(resp.StatusCode)
|
||||
return result
|
||||
}
|
||||
85
internal/deepseek/proxy_test.go
Normal file
85
internal/deepseek/proxy_test.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package deepseek
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProxyDialAddressUsesLocalResolutionForSocks5(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
resolved, err := proxyDialAddress(ctx, "socks5", "example.com:443", func(_ context.Context, network, host string) ([]string, error) {
|
||||
if network != "ip" {
|
||||
t.Fatalf("unexpected lookup network: %q", network)
|
||||
}
|
||||
if host != "example.com" {
|
||||
t.Fatalf("unexpected lookup host: %q", host)
|
||||
}
|
||||
return []string{"203.0.113.10"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("proxyDialAddress returned error: %v", err)
|
||||
}
|
||||
if resolved != "203.0.113.10:443" {
|
||||
t.Fatalf("expected locally resolved address, got %q", resolved)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyDialAddressKeepsHostnameForSocks5h(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
lookups := 0
|
||||
resolved, err := proxyDialAddress(ctx, "socks5h", "example.com:443", func(_ context.Context, network, host string) ([]string, error) {
|
||||
lookups++
|
||||
return []string{"203.0.113.10"}, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("proxyDialAddress returned error: %v", err)
|
||||
}
|
||||
if resolved != "example.com:443" {
|
||||
t.Fatalf("expected hostname preserved for remote DNS, got %q", resolved)
|
||||
}
|
||||
if lookups != 0 {
|
||||
t.Fatalf("expected no local DNS lookup for socks5h, got %d", lookups)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyProxyConnectivityHeadersUsesBaseHeaders(t *testing.T) {
|
||||
req, err := http.NewRequest(http.MethodGet, "https://chat.deepseek.com/", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("http.NewRequest returned error: %v", err)
|
||||
}
|
||||
|
||||
applyProxyConnectivityHeaders(req)
|
||||
|
||||
for key, want := range BaseHeaders {
|
||||
if got := req.Header.Get(key); got != want {
|
||||
t.Fatalf("expected header %q=%q, got %q", key, want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyConnectivityStatus(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
success bool
|
||||
wantText string
|
||||
}{
|
||||
{name: "ok", statusCode: 200, success: true, wantText: "HTTP 200"},
|
||||
{name: "challenge", statusCode: 403, success: true, wantText: "风控或挑战"},
|
||||
{name: "upstream error", statusCode: 502, success: false, wantText: "HTTP 502"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
success, message := proxyConnectivityStatus(tc.statusCode)
|
||||
if success != tc.success {
|
||||
t.Fatalf("expected success=%v, got %v", tc.success, success)
|
||||
}
|
||||
if message == "" || !strings.Contains(message, tc.wantText) {
|
||||
t.Fatalf("expected message to contain %q, got %q", tc.wantText, message)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,21 +15,33 @@ type Doer interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
type DialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
type Client struct {
|
||||
http *http.Client
|
||||
}
|
||||
|
||||
func New(timeout time.Duration) *Client {
|
||||
return NewWithDialContext(timeout, nil)
|
||||
}
|
||||
|
||||
func NewWithDialContext(timeout time.Duration, dialContext DialContextFunc) *Client {
|
||||
useEnvProxy := dialContext == nil
|
||||
if dialContext == nil {
|
||||
dialContext = (&net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}).DialContext
|
||||
}
|
||||
base := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
ForceAttemptHTTP2: false,
|
||||
MaxIdleConns: 200,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DialContext: (&net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}).DialContext,
|
||||
DialTLSContext: safariTLSDialer(),
|
||||
DialContext: dialContext,
|
||||
DialTLSContext: safariTLSDialer(dialContext),
|
||||
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
||||
}
|
||||
if useEnvProxy {
|
||||
base.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
return &Client{http: &http.Client{Timeout: timeout, Transport: base}}
|
||||
}
|
||||
|
||||
@@ -37,10 +49,31 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
|
||||
return c.http.Do(req)
|
||||
}
|
||||
|
||||
func safariTLSDialer() func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
var dialer net.Dialer
|
||||
func NewFallbackClient(timeout time.Duration, dialContext DialContextFunc) *http.Client {
|
||||
useEnvProxy := dialContext == nil
|
||||
if dialContext == nil {
|
||||
dialContext = (&net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}).DialContext
|
||||
}
|
||||
base := &http.Transport{
|
||||
ForceAttemptHTTP2: false,
|
||||
MaxIdleConns: 200,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
DialContext: dialContext,
|
||||
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
|
||||
}
|
||||
if useEnvProxy {
|
||||
base.Proxy = http.ProxyFromEnvironment
|
||||
}
|
||||
return &http.Client{Timeout: timeout, Transport: base}
|
||||
}
|
||||
|
||||
func safariTLSDialer(dialContext DialContextFunc) func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
if dialContext == nil {
|
||||
dialContext = (&net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}).DialContext
|
||||
}
|
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
plainConn, err := dialer.DialContext(ctx, network, addr)
|
||||
plainConn, err := dialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user