feat(proxy): add proxy IP management and account routing

Add admin CRUD and connectivity checks for SOCKS5/SOCKS5H proxy nodes.

Allow accounts to bind to a proxy, route DeepSeek requests through the selected node, and expose proxy management in the admin UI.
This commit is contained in:
Jason.li
2026-04-07 02:05:25 +08:00
parent 1c95942e5d
commit 8ae2ea10c8
30 changed files with 1675 additions and 51 deletions

View File

@@ -26,9 +26,15 @@ func RegisterRoutes(r chi.Router, h *Handler) {
pr.Get("/config/export", h.configExport)
pr.Post("/keys", h.addKey)
pr.Delete("/keys/{key}", h.deleteKey)
pr.Get("/proxies", h.listProxies)
pr.Post("/proxies", h.addProxy)
pr.Put("/proxies/{proxyID}", h.updateProxy)
pr.Delete("/proxies/{proxyID}", h.deleteProxy)
pr.Post("/proxies/test", h.testProxy)
pr.Get("/accounts", h.listAccounts)
pr.Post("/accounts", h.addAccount)
pr.Delete("/accounts/{identifier}", h.deleteAccount)
pr.Put("/accounts/{identifier}/proxy", h.updateAccountProxy)
pr.Get("/queue/status", h.queueStatus)
pr.Post("/accounts/test", h.testSingleAccount)
pr.Post("/accounts/test-all", h.testAllAccounts)

View File

@@ -68,6 +68,7 @@ func (h *Handler) listAccounts(w http.ResponseWriter, r *http.Request) {
"identifier": acc.Identifier(),
"email": acc.Email,
"mobile": acc.Mobile,
"proxy_id": acc.ProxyID,
"has_password": acc.Password != "",
"has_token": token != "",
"token_preview": preview,
@@ -86,6 +87,11 @@ func (h *Handler) addAccount(w http.ResponseWriter, r *http.Request) {
return
}
err := h.Store.Update(func(c *config.Config) error {
if acc.ProxyID != "" {
if _, ok := findProxyByID(*c, acc.ProxyID); !ok {
return fmt.Errorf("代理不存在")
}
}
mobileKey := config.CanonicalMobileKey(acc.Mobile)
for _, a := range c.Accounts {
if acc.Email != "" && a.Email == acc.Email {

View File

@@ -115,10 +115,11 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
result["message"] = "登录成功但写入运行时 token 失败: " + err.Error()
return result
}
authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token}
sessionID, err := h.DS.CreateSession(ctx, authCtx, 1)
authCtx := &authn.RequestAuth{UseConfigToken: false, DeepSeekToken: token, AccountID: identifier, Account: acc}
proxyCtx := authn.WithAuth(ctx, authCtx)
sessionID, err := h.DS.CreateSession(proxyCtx, authCtx, 1)
if err != nil {
newToken, loginErr := h.DS.Login(ctx, acc)
newToken, loginErr := h.DS.Login(proxyCtx, acc)
if loginErr != nil {
result["message"] = "创建会话失败: " + err.Error()
return result
@@ -129,7 +130,7 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
result["message"] = "刷新 token 成功但写入运行时 token 失败: " + err.Error()
return result
}
sessionID, err = h.DS.CreateSession(ctx, authCtx, 1)
sessionID, err = h.DS.CreateSession(proxyCtx, authCtx, 1)
if err != nil {
result["message"] = "创建会话失败: " + err.Error()
return result
@@ -137,7 +138,7 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
}
// 获取会话数量
sessionStats, sessionErr := h.DS.GetSessionCountForToken(ctx, token)
sessionStats, sessionErr := h.DS.GetSessionCountForToken(proxyCtx, token)
if sessionErr == nil && sessionStats != nil {
result["session_count"] = sessionStats.FirstPageCount
}
@@ -153,13 +154,13 @@ func (h *Handler) testAccount(ctx context.Context, acc config.Account, model, me
thinking, search = false, false
}
_ = search
pow, err := h.DS.GetPow(ctx, authCtx, 1)
pow, err := h.DS.GetPow(proxyCtx, authCtx, 1)
if err != nil {
result["message"] = "获取 PoW 失败: " + err.Error()
return result
}
payload := map[string]any{"chat_session_id": sessionID, "prompt": deepseek.MessagesPrepare([]map[string]any{{"role": "user", "content": message}}), "ref_file_ids": []any{}, "thinking_enabled": thinking, "search_enabled": search}
resp, err := h.DS.CallCompletion(ctx, authCtx, payload, pow, 1)
resp, err := h.DS.CallCompletion(proxyCtx, authCtx, payload, pow, 1)
if err != nil {
result["message"] = "请求失败: " + err.Error()
return result
@@ -244,25 +245,29 @@ func (h *Handler) deleteAllSessions(w http.ResponseWriter, r *http.Request) {
}
// 每次先登录刷新一次 token避免使用过期 token。
token, err := h.DS.Login(r.Context(), acc)
authCtx := &authn.RequestAuth{UseConfigToken: false, AccountID: acc.Identifier(), Account: acc}
proxyCtx := authn.WithAuth(r.Context(), authCtx)
token, err := h.DS.Login(proxyCtx, acc)
if err != nil {
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "登录失败: " + err.Error()})
return
}
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
authCtx.DeepSeekToken = token
// 删除所有会话
err = h.DS.DeleteAllSessionsForToken(r.Context(), token)
err = h.DS.DeleteAllSessionsForToken(proxyCtx, token)
if err != nil {
// token 可能过期,尝试重新登录并重试一次
newToken, loginErr := h.DS.Login(r.Context(), acc)
newToken, loginErr := h.DS.Login(proxyCtx, acc)
if loginErr != nil {
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "删除失败: " + err.Error()})
return
}
token = newToken
_ = h.Store.UpdateAccountToken(acc.Identifier(), token)
if retryErr := h.DS.DeleteAllSessionsForToken(r.Context(), token); retryErr != nil {
authCtx.DeepSeekToken = token
if retryErr := h.DS.DeleteAllSessionsForToken(proxyCtx, token); retryErr != nil {
writeJSON(w, http.StatusOK, map[string]any{"success": false, "message": "删除失败: " + retryErr.Error()})
return
}

View File

@@ -3,6 +3,8 @@ package admin
import (
"net/http"
"strings"
"ds2api/internal/config"
)
func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
@@ -10,6 +12,7 @@ func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
safe := map[string]any{
"keys": snap.Keys,
"accounts": []map[string]any{},
"proxies": []map[string]any{},
"env_backed": h.Store.IsEnvBacked(),
"env_source_present": h.Store.HasEnvConfigSource(),
"env_writeback_enabled": h.Store.IsEnvWritebackEnabled(),
@@ -36,12 +39,27 @@ func (h *Handler) getConfig(w http.ResponseWriter, _ *http.Request) {
"identifier": acc.Identifier(),
"email": acc.Email,
"mobile": acc.Mobile,
"proxy_id": acc.ProxyID,
"has_password": strings.TrimSpace(acc.Password) != "",
"has_token": token != "",
"token_preview": preview,
})
}
safe["accounts"] = accounts
proxies := make([]map[string]any, 0, len(snap.Proxies))
for _, proxy := range snap.Proxies {
proxy = config.NormalizeProxy(proxy)
proxies = append(proxies, map[string]any{
"id": proxy.ID,
"name": proxy.Name,
"type": proxy.Type,
"host": proxy.Host,
"port": proxy.Port,
"username": proxy.Username,
"has_password": strings.TrimSpace(proxy.Password) != "",
})
}
safe["proxies"] = proxies
writeJSON(w, http.StatusOK, safe)
}

View File

@@ -0,0 +1,189 @@
package admin
import (
"context"
"encoding/json"
"net/http"
"net/url"
"strings"
"github.com/go-chi/chi/v5"
"ds2api/internal/config"
"ds2api/internal/deepseek"
)
var proxyConnectivityTester = func(ctx context.Context, proxy config.Proxy) map[string]any {
return deepseek.TestProxyConnectivity(ctx, proxy)
}
func validateProxyMutation(cfg *config.Config) error {
if cfg == nil {
return nil
}
if err := config.ValidateProxyConfig(cfg.Proxies); err != nil {
return err
}
return config.ValidateAccountProxyReferences(cfg.Accounts, cfg.Proxies)
}
func (h *Handler) listProxies(w http.ResponseWriter, _ *http.Request) {
proxies := h.Store.Snapshot().Proxies
items := make([]map[string]any, 0, len(proxies))
for _, proxy := range proxies {
proxy = config.NormalizeProxy(proxy)
items = append(items, map[string]any{
"id": proxy.ID,
"name": proxy.Name,
"type": proxy.Type,
"host": proxy.Host,
"port": proxy.Port,
"username": proxy.Username,
"has_password": strings.TrimSpace(proxy.Password) != "",
})
}
writeJSON(w, http.StatusOK, map[string]any{"items": items, "total": len(items)})
}
func (h *Handler) addProxy(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
proxy := toProxy(req)
err := h.Store.Update(func(c *config.Config) error {
c.Proxies = append(c.Proxies, proxy)
return validateProxyMutation(c)
})
if err != nil {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{"success": true, "proxy": proxy})
}
func (h *Handler) updateProxy(w http.ResponseWriter, r *http.Request) {
proxyID := chi.URLParam(r, "proxyID")
if decoded, err := url.PathUnescape(proxyID); err == nil {
proxyID = decoded
}
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
proxy := toProxy(req)
proxy.ID = strings.TrimSpace(proxyID)
err := h.Store.Update(func(c *config.Config) error {
for i, existing := range c.Proxies {
existing = config.NormalizeProxy(existing)
if existing.ID != proxy.ID {
continue
}
if proxy.Password == "" {
proxy.Password = existing.Password
}
c.Proxies[i] = proxy
return validateProxyMutation(c)
}
return newRequestError("代理不存在")
})
if err != nil {
if detail, ok := requestErrorDetail(err); ok {
writeJSON(w, http.StatusNotFound, map[string]any{"detail": detail})
return
}
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{"success": true, "proxy": proxy})
}
func (h *Handler) deleteProxy(w http.ResponseWriter, r *http.Request) {
proxyID := chi.URLParam(r, "proxyID")
if decoded, err := url.PathUnescape(proxyID); err == nil {
proxyID = decoded
}
err := h.Store.Update(func(c *config.Config) error {
idx := -1
for i, existing := range c.Proxies {
existing = config.NormalizeProxy(existing)
if existing.ID == strings.TrimSpace(proxyID) {
idx = i
break
}
}
if idx < 0 {
return newRequestError("代理不存在")
}
c.Proxies = append(c.Proxies[:idx], c.Proxies[idx+1:]...)
for i := range c.Accounts {
if strings.TrimSpace(c.Accounts[i].ProxyID) == strings.TrimSpace(proxyID) {
c.Accounts[i].ProxyID = ""
}
}
return validateProxyMutation(c)
})
if err != nil {
if detail, ok := requestErrorDetail(err); ok {
writeJSON(w, http.StatusNotFound, map[string]any{"detail": detail})
return
}
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
writeJSON(w, http.StatusOK, map[string]any{"success": true})
}
func (h *Handler) testProxy(w http.ResponseWriter, r *http.Request) {
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
proxyID := fieldString(req, "proxy_id")
var proxy config.Proxy
if proxyID != "" {
var ok bool
proxy, ok = findProxyByID(h.Store.Snapshot(), proxyID)
if !ok {
writeJSON(w, http.StatusNotFound, map[string]any{"detail": "代理不存在"})
return
}
} else {
proxy = toProxy(req)
}
result := proxyConnectivityTester(r.Context(), proxy)
writeJSON(w, http.StatusOK, result)
}
func (h *Handler) updateAccountProxy(w http.ResponseWriter, r *http.Request) {
identifier := chi.URLParam(r, "identifier")
if decoded, err := url.PathUnescape(identifier); err == nil {
identifier = decoded
}
var req map[string]any
_ = json.NewDecoder(r.Body).Decode(&req)
proxyID := fieldString(req, "proxy_id")
err := h.Store.Update(func(c *config.Config) error {
if proxyID != "" {
if _, ok := findProxyByID(*c, proxyID); !ok {
return newRequestError("代理不存在")
}
}
for i, acc := range c.Accounts {
if !accountMatchesIdentifier(acc, identifier) {
continue
}
c.Accounts[i].ProxyID = proxyID
return validateProxyMutation(c)
}
return newRequestError("账号不存在")
})
if err != nil {
if detail, ok := requestErrorDetail(err); ok {
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": detail})
return
}
writeJSON(w, http.StatusBadRequest, map[string]any{"detail": err.Error()})
return
}
h.Pool.Reset()
writeJSON(w, http.StatusOK, map[string]any{"success": true, "proxy_id": proxyID})
}

View File

@@ -0,0 +1,193 @@
package admin
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/go-chi/chi/v5"
"ds2api/internal/account"
"ds2api/internal/config"
)
func newAdminProxyTestHandler(t *testing.T, raw string) *Handler {
t.Helper()
t.Setenv("DS2API_CONFIG_JSON", raw)
store := config.LoadStore()
return &Handler{
Store: store,
Pool: account.NewPool(store),
}
}
func TestAddProxyPersistsNormalizedProxy(t *testing.T) {
h := newAdminProxyTestHandler(t, `{"accounts":[]}`)
r := chi.NewRouter()
r.Post("/admin/proxies", h.addProxy)
req := httptest.NewRequest(http.MethodPost, "/admin/proxies", bytes.NewBufferString(`{
"name":" HK Exit ",
"type":" SOCKS5H ",
"host":" 127.0.0.1 ",
"port":1081,
"username":" user ",
"password":" pass "
}`))
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
}
proxies := h.Store.Snapshot().Proxies
if len(proxies) != 1 {
t.Fatalf("expected 1 proxy, got %d", len(proxies))
}
if proxies[0].Name != "HK Exit" {
t.Fatalf("unexpected proxy name: %#v", proxies[0])
}
if proxies[0].Type != "socks5h" {
t.Fatalf("unexpected proxy type: %#v", proxies[0])
}
if proxies[0].Username != "user" || proxies[0].Password != "pass" {
t.Fatalf("expected trimmed credentials, got %#v", proxies[0])
}
if proxies[0].ID == "" {
t.Fatalf("expected generated proxy id, got %#v", proxies[0])
}
}
func TestAddProxyDoesNotFailOnUnrelatedInvalidRuntimeConfig(t *testing.T) {
router := newHTTPAdminHarness(t, `{
"keys":["k1"],
"runtime":{
"account_max_inflight":8,
"global_max_inflight":4
}
}`, &testingDSMock{})
rec := httptest.NewRecorder()
router.ServeHTTP(rec, adminReq(http.MethodPost, "/proxies", []byte(`{
"name":"HK Exit",
"type":"socks5h",
"host":"127.0.0.1",
"port":1080
}`)))
if rec.Code != http.StatusOK {
t.Fatalf("expected add proxy success despite unrelated runtime issue, got %d body=%s", rec.Code, rec.Body.String())
}
readRec := httptest.NewRecorder()
router.ServeHTTP(readRec, adminReq(http.MethodGet, "/config", nil))
if readRec.Code != http.StatusOK {
t.Fatalf("config read status=%d body=%s", readRec.Code, readRec.Body.String())
}
var payload map[string]any
if err := json.Unmarshal(readRec.Body.Bytes(), &payload); err != nil {
t.Fatalf("decode config response: %v", err)
}
proxies, _ := payload["proxies"].([]any)
if len(proxies) != 1 {
t.Fatalf("expected proxy to be persisted, got %#v", payload["proxies"])
}
}
func TestDeleteProxyClearsAssignedAccountProxyID(t *testing.T) {
h := newAdminProxyTestHandler(t, `{
"proxies":[{"id":"proxy-1","name":"Node 1","type":"socks5","host":"127.0.0.1","port":1080}],
"accounts":[{"email":"u@example.com","password":"pwd","proxy_id":"proxy-1"}]
}`)
r := chi.NewRouter()
r.Delete("/admin/proxies/{proxyID}", h.deleteProxy)
req := httptest.NewRequest(http.MethodDelete, "/admin/proxies/proxy-1", nil)
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
}
snap := h.Store.Snapshot()
if len(snap.Proxies) != 0 {
t.Fatalf("expected proxy removed, got %#v", snap.Proxies)
}
if len(snap.Accounts) != 1 {
t.Fatalf("expected account kept, got %#v", snap.Accounts)
}
if snap.Accounts[0].ProxyID != "" {
t.Fatalf("expected proxy assignment cleared, got %#v", snap.Accounts[0])
}
}
func TestUpdateAccountProxyAssignsProxyID(t *testing.T) {
h := newAdminProxyTestHandler(t, `{
"proxies":[{"id":"proxy-1","name":"Node 1","type":"socks5h","host":"127.0.0.1","port":1080}],
"accounts":[{"email":"u@example.com","password":"pwd"}]
}`)
r := chi.NewRouter()
r.Put("/admin/accounts/{identifier}/proxy", h.updateAccountProxy)
req := httptest.NewRequest(http.MethodPut, "/admin/accounts/u@example.com/proxy", bytes.NewBufferString(`{"proxy_id":"proxy-1"}`))
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
}
acc, ok := h.Store.FindAccount("u@example.com")
if !ok {
t.Fatal("expected account")
}
if acc.ProxyID != "proxy-1" {
t.Fatalf("expected proxy assigned, got %#v", acc)
}
}
func TestTestProxyUsesStoredProxy(t *testing.T) {
h := newAdminProxyTestHandler(t, `{
"proxies":[{"id":"proxy-1","name":"Node 1","type":"socks5h","host":"127.0.0.1","port":1080}]
}`)
original := proxyConnectivityTester
defer func() { proxyConnectivityTester = original }()
var got config.Proxy
proxyConnectivityTester = func(_ context.Context, proxy config.Proxy) map[string]any {
got = proxy
return map[string]any{
"success": true,
"proxy_id": proxy.ID,
"proxy_type": proxy.Type,
"response_time": 12,
}
}
r := chi.NewRouter()
r.Post("/admin/proxies/test", h.testProxy)
req := httptest.NewRequest(http.MethodPost, "/admin/proxies/test", bytes.NewBufferString(`{"proxy_id":"proxy-1"}`))
rec := httptest.NewRecorder()
r.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d body=%s", rec.Code, rec.Body.String())
}
if got.ID != "proxy-1" || got.Type != "socks5h" {
t.Fatalf("expected stored proxy passed to tester, got %#v", got)
}
var payload map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
t.Fatalf("decode response: %v", err)
}
if ok, _ := payload["success"].(bool); !ok {
t.Fatalf("expected success payload, got %#v", payload)
}
}

View File

@@ -65,6 +65,7 @@ func toAccount(m map[string]any) config.Account {
Email: email,
Mobile: mobile,
Password: fieldString(m, "password"),
ProxyID: fieldString(m, "proxy_id"),
}
}
@@ -100,9 +101,36 @@ func accountMatchesIdentifier(acc config.Account, identifier string) bool {
func normalizeAccountForStorage(acc config.Account) config.Account {
acc.Email = strings.TrimSpace(acc.Email)
acc.Mobile = config.NormalizeMobileForStorage(acc.Mobile)
acc.ProxyID = strings.TrimSpace(acc.ProxyID)
return acc
}
func toProxy(m map[string]any) config.Proxy {
return config.NormalizeProxy(config.Proxy{
ID: fieldString(m, "id"),
Name: fieldString(m, "name"),
Type: fieldString(m, "type"),
Host: fieldString(m, "host"),
Port: intFrom(m["port"]),
Username: fieldString(m, "username"),
Password: fieldString(m, "password"),
})
}
func findProxyByID(c config.Config, proxyID string) (config.Proxy, bool) {
id := strings.TrimSpace(proxyID)
if id == "" {
return config.Proxy{}, false
}
for _, proxy := range c.Proxies {
proxy = config.NormalizeProxy(proxy)
if proxy.ID == id {
return proxy, true
}
}
return config.Proxy{}, false
}
func accountDedupeKey(acc config.Account) string {
if email := strings.TrimSpace(acc.Email); email != "" {
return "email:" + email

View File

@@ -20,6 +20,9 @@ func (c Config) MarshalJSON() ([]byte, error) {
if len(c.Accounts) > 0 {
m["accounts"] = c.Accounts
}
if len(c.Proxies) > 0 {
m["proxies"] = c.Proxies
}
if len(c.ClaudeMapping) > 0 {
m["claude_mapping"] = c.ClaudeMapping
}
@@ -70,6 +73,10 @@ func (c *Config) UnmarshalJSON(b []byte) error {
if err := json.Unmarshal(v, &c.Accounts); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "proxies":
if err := json.Unmarshal(v, &c.Proxies); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
}
case "claude_mapping":
if err := json.Unmarshal(v, &c.ClaudeMapping); err != nil {
return fmt.Errorf("invalid field %q: %w", k, err)
@@ -130,6 +137,7 @@ func (c Config) Clone() Config {
clone := Config{
Keys: slices.Clone(c.Keys),
Accounts: slices.Clone(c.Accounts),
Proxies: slices.Clone(c.Proxies),
ClaudeMapping: cloneStringMap(c.ClaudeMapping),
ClaudeModelMap: cloneStringMap(c.ClaudeModelMap),
ModelAliases: cloneStringMap(c.ModelAliases),

View File

@@ -1,8 +1,16 @@
package config
import (
"crypto/sha1"
"encoding/hex"
"fmt"
"strings"
)
type Config struct {
Keys []string `json:"keys,omitempty"`
Accounts []Account `json:"accounts,omitempty"`
Proxies []Proxy `json:"proxies,omitempty"`
ClaudeMapping map[string]string `json:"claude_mapping,omitempty"`
ClaudeModelMap map[string]string `json:"claude_model_mapping,omitempty"`
ModelAliases map[string]string `json:"model_aliases,omitempty"`
@@ -22,6 +30,38 @@ type Account struct {
Mobile string `json:"mobile,omitempty"`
Password string `json:"password,omitempty"`
Token string `json:"token,omitempty"`
ProxyID string `json:"proxy_id,omitempty"`
}
type Proxy struct {
ID string `json:"id,omitempty"`
Name string `json:"name,omitempty"`
Type string `json:"type,omitempty"`
Host string `json:"host,omitempty"`
Port int `json:"port,omitempty"`
Username string `json:"username,omitempty"`
Password string `json:"password,omitempty"`
}
func NormalizeProxy(p Proxy) Proxy {
p.ID = strings.TrimSpace(p.ID)
p.Name = strings.TrimSpace(p.Name)
p.Type = strings.ToLower(strings.TrimSpace(p.Type))
p.Host = strings.TrimSpace(p.Host)
p.Username = strings.TrimSpace(p.Username)
p.Password = strings.TrimSpace(p.Password)
if p.ID == "" {
p.ID = StableProxyID(p)
}
if p.Name == "" && p.Host != "" && p.Port > 0 {
p.Name = fmt.Sprintf("%s:%d", p.Host, p.Port)
}
return p
}
func StableProxyID(p Proxy) string {
sum := sha1.Sum([]byte(strings.ToLower(strings.TrimSpace(p.Type)) + "|" + strings.ToLower(strings.TrimSpace(p.Host)) + "|" + fmt.Sprintf("%d", p.Port) + "|" + strings.TrimSpace(p.Username)))
return "proxy_" + hex.EncodeToString(sum[:6])
}
func (c *Config) ClearAccountTokens() {

View File

@@ -32,6 +32,47 @@ func TestLoadStoreClearsTokensFromConfigInput(t *testing.T) {
}
}
func TestLoadStorePreservesProxiesAndAccountProxyAssignment(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{
"proxies":[
{
"id":"proxy-sh-1",
"name":"Shanghai Exit",
"type":"socks5h",
"host":"127.0.0.1",
"port":1080,
"username":"demo",
"password":"secret"
}
],
"accounts":[
{
"email":"u@example.com",
"password":"p",
"proxy_id":"proxy-sh-1"
}
]
}`)
store := LoadStore()
snap := store.Snapshot()
if len(snap.Proxies) != 1 {
t.Fatalf("expected 1 proxy, got %d", len(snap.Proxies))
}
if snap.Proxies[0].ID != "proxy-sh-1" {
t.Fatalf("unexpected proxy id: %#v", snap.Proxies[0])
}
if snap.Proxies[0].Type != "socks5h" {
t.Fatalf("unexpected proxy type: %#v", snap.Proxies[0])
}
if len(snap.Accounts) != 1 {
t.Fatalf("expected 1 account, got %d", len(snap.Accounts))
}
if snap.Accounts[0].ProxyID != "proxy-sh-1" {
t.Fatalf("expected account proxy assignment preserved, got %#v", snap.Accounts[0])
}
}
func TestLoadStoreDropsLegacyTokenOnlyAccounts(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{
"accounts":[

View File

@@ -6,6 +6,9 @@ import (
)
func ValidateConfig(c Config) error {
if err := ValidateProxyConfig(c.Proxies); err != nil {
return err
}
if err := ValidateAdminConfig(c.Admin); err != nil {
return err
}
@@ -21,6 +24,55 @@ func ValidateConfig(c Config) error {
if err := ValidateAutoDeleteConfig(c.AutoDelete); err != nil {
return err
}
if err := ValidateAccountProxyReferences(c.Accounts, c.Proxies); err != nil {
return err
}
return nil
}
func ValidateProxyConfig(proxies []Proxy) error {
seen := make(map[string]struct{}, len(proxies))
for _, proxy := range proxies {
proxy = NormalizeProxy(proxy)
if err := ValidateTrimmedString("proxies.id", proxy.ID, true); err != nil {
return err
}
switch proxy.Type {
case "socks5", "socks5h":
default:
return fmt.Errorf("proxies.type must be one of socks5, socks5h")
}
if err := ValidateTrimmedString("proxies.host", proxy.Host, true); err != nil {
return err
}
if err := ValidateIntRange("proxies.port", proxy.Port, 1, 65535, true); err != nil {
return err
}
if _, ok := seen[proxy.ID]; ok {
return fmt.Errorf("duplicate proxy id: %s", proxy.ID)
}
seen[proxy.ID] = struct{}{}
}
return nil
}
func ValidateAccountProxyReferences(accounts []Account, proxies []Proxy) error {
if len(accounts) == 0 {
return nil
}
ids := make(map[string]struct{}, len(proxies))
for _, proxy := range proxies {
ids[NormalizeProxy(proxy).ID] = struct{}{}
}
for _, acc := range accounts {
proxyID := strings.TrimSpace(acc.ProxyID)
if proxyID == "" {
continue
}
if _, ok := ids[proxyID]; !ok {
return fmt.Errorf("account proxy_id references unknown proxy: %s", proxyID)
}
}
return nil
}

View File

@@ -13,6 +13,7 @@ import (
)
func (c *Client) Login(ctx context.Context, acc config.Account) (string, error) {
clients := c.requestClientsForAccount(acc)
payload := map[string]any{
"password": strings.TrimSpace(acc.Password),
"device_id": "deepseek_to_api",
@@ -27,7 +28,7 @@ func (c *Client) Login(ctx context.Context, acc config.Account) (string, error)
} else {
return "", errors.New("missing email/mobile")
}
resp, err := c.postJSON(ctx, c.regular, DeepSeekLoginURL, BaseHeaders, payload)
resp, err := c.postJSON(ctx, clients.regular, DeepSeekLoginURL, BaseHeaders, payload)
if err != nil {
return "", err
}
@@ -52,11 +53,12 @@ func (c *Client) CreateSession(ctx context.Context, a *auth.RequestAuth, maxAtte
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
clients := c.requestClientsForAuth(ctx, a)
attempts := 0
refreshed := false
for attempts < maxAttempts {
headers := c.authHeaders(a.DeepSeekToken)
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekCreateSessionURL, headers, map[string]any{"agent": "chat"})
if err != nil {
config.Logger.Warn("[create_session] request error", "error", err, "account", a.AccountID)
attempts++
@@ -94,11 +96,12 @@ func (c *Client) GetPow(ctx context.Context, a *auth.RequestAuth, maxAttempts in
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
clients := c.requestClientsForAuth(ctx, a)
attempts := 0
refreshed := false
for attempts < maxAttempts {
headers := c.authHeaders(a.DeepSeekToken)
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekCreatePowURL, headers, map[string]any{"target_path": "/api/v0/chat/completion"})
if err != nil {
config.Logger.Warn("[get_pow] request error", "error", err, "account", a.AccountID)
attempts++

View File

@@ -10,18 +10,20 @@ import (
"ds2api/internal/auth"
"ds2api/internal/config"
trans "ds2api/internal/deepseek/transport"
)
func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payload map[string]any, powResp string, maxAttempts int) (*http.Response, error) {
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
clients := c.requestClientsForAuth(ctx, a)
headers := c.authHeaders(a.DeepSeekToken)
headers["x-ds-pow-response"] = powResp
captureSession := c.capture.Start("deepseek_completion", DeepSeekCompletionURL, a.AccountID, payload)
attempts := 0
for attempts < maxAttempts {
resp, err := c.streamPost(ctx, DeepSeekCompletionURL, headers, payload)
resp, err := c.streamPost(ctx, clients.stream, DeepSeekCompletionURL, headers, payload)
if err != nil {
attempts++
time.Sleep(time.Second)
@@ -44,11 +46,12 @@ func (c *Client) CallCompletion(ctx context.Context, a *auth.RequestAuth, payloa
return nil, errors.New("completion failed")
}
func (c *Client) streamPost(ctx context.Context, url string, headers map[string]string, payload any) (*http.Response, error) {
func (c *Client) streamPost(ctx context.Context, doer trans.Doer, url string, headers map[string]string, payload any) (*http.Response, error) {
b, err := json.Marshal(payload)
if err != nil {
return nil, err
}
clients := c.requestClientsFromContext(ctx)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if err != nil {
return nil, err
@@ -56,7 +59,7 @@ func (c *Client) streamPost(ctx context.Context, url string, headers map[string]
for k, v := range headers {
req.Header.Set(k, v)
}
resp, err := c.stream.Do(req)
resp, err := doer.Do(req)
if err != nil {
config.Logger.Warn("[deepseek] fingerprint stream request failed, fallback to std transport", "url", url, "error", err)
req2, reqErr := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
@@ -66,7 +69,7 @@ func (c *Client) streamPost(ctx context.Context, url string, headers map[string]
for k, v := range headers {
req2.Header.Set(k, v)
}
return c.fallbackS.Do(req2)
return clients.fallbackS.Do(req2)
}
return resp, nil
}

View File

@@ -51,6 +51,7 @@ func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionI
if strings.TrimSpace(sessionID) == "" || responseMessageID <= 0 {
return nil, errors.New("missing continue identifiers")
}
clients := c.requestClientsForAuth(ctx, a)
headers := c.authHeaders(a.DeepSeekToken)
headers["x-ds-pow-response"] = powResp
payload := map[string]any{
@@ -60,7 +61,7 @@ func (c *Client) callContinue(ctx context.Context, a *auth.RequestAuth, sessionI
}
config.Logger.Info("[auto_continue] calling continue", "session_id", sessionID, "message_id", responseMessageID)
captureSession := c.capture.Start("deepseek_continue", DeepSeekContinueURL, a.AccountID, payload)
resp, err := c.streamPost(ctx, DeepSeekContinueURL, headers, payload)
resp, err := c.streamPost(ctx, clients.stream, DeepSeekContinueURL, headers, payload)
if err != nil {
return nil, err
}

View File

@@ -3,6 +3,7 @@ package deepseek
import (
"context"
"net/http"
"sync"
"time"
"ds2api/internal/auth"
@@ -24,18 +25,22 @@ type Client struct {
fallback *http.Client
fallbackS *http.Client
maxRetries int
proxyClientsMu sync.RWMutex
proxyClients map[string]requestClients
}
func NewClient(store *config.Store, resolver *auth.Resolver) *Client {
return &Client{
Store: store,
Auth: resolver,
capture: devcapture.Global(),
regular: trans.New(60 * time.Second),
stream: trans.New(0),
fallback: &http.Client{Timeout: 60 * time.Second},
fallbackS: &http.Client{Timeout: 0},
maxRetries: 3,
Store: store,
Auth: resolver,
capture: devcapture.Global(),
regular: trans.New(60 * time.Second),
stream: trans.New(0),
fallback: &http.Client{Timeout: 60 * time.Second},
fallbackS: &http.Client{Timeout: 0},
maxRetries: 3,
proxyClients: map[string]requestClients{},
}
}

View File

@@ -27,6 +27,7 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
if err != nil {
return nil, 0, err
}
clients := c.requestClientsFromContext(ctx)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(b))
if err != nil {
return nil, 0, err
@@ -44,7 +45,7 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
for k, v := range headers {
req2.Header.Set(k, v)
}
resp, err = c.fallback.Do(req2)
resp, err = clients.fallback.Do(req2)
if err != nil {
return nil, 0, err
}
@@ -64,6 +65,7 @@ func (c *Client) postJSONWithStatus(ctx context.Context, doer trans.Doer, url st
}
func (c *Client) getJSONWithStatus(ctx context.Context, doer trans.Doer, url string, headers map[string]string) (map[string]any, int, error) {
clients := c.requestClientsFromContext(ctx)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, 0, err
@@ -81,7 +83,7 @@ func (c *Client) getJSONWithStatus(ctx context.Context, doer trans.Doer, url str
for k, v := range headers {
req2.Header.Set(k, v)
}
resp, err = c.fallback.Do(req2)
resp, err = clients.fallback.Do(req2)
if err != nil {
return nil, 0, err
}

View File

@@ -36,6 +36,7 @@ func (c *Client) GetSessionCount(ctx context.Context, a *auth.RequestAuth, maxAt
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
clients := c.requestClientsForAuth(ctx, a)
stats := &SessionStats{
AccountID: a.AccountID,
@@ -50,7 +51,7 @@ func (c *Client) GetSessionCount(ctx context.Context, a *auth.RequestAuth, maxAt
// 构建请求 URL
reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
resp, status, err := c.getJSONWithStatus(ctx, clients.regular, reqURL, headers)
if err != nil {
config.Logger.Warn("[get_session_count] request error", "error", err, "account", a.AccountID)
attempts++
@@ -106,10 +107,11 @@ func (c *Client) GetSessionCount(ctx context.Context, a *auth.RequestAuth, maxAt
// GetSessionCountForToken 直接使用 token 获取会话数量(直通模式)
func (c *Client) GetSessionCountForToken(ctx context.Context, token string) (*SessionStats, error) {
clients := c.requestClientsFromContext(ctx)
headers := c.authHeaders(token)
reqURL := DeepSeekFetchSessionURL + "?lte_cursor.pinned=false"
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
resp, status, err := c.getJSONWithStatus(ctx, clients.regular, reqURL, headers)
if err != nil {
return nil, err
}
@@ -160,7 +162,7 @@ func (c *Client) GetSessionCountAll(ctx context.Context) []*SessionStats {
// 如果没有 token尝试登录获取
if token == "" {
var err error
token, err = c.Login(ctx, acc)
token, err = c.Login(auth.WithAuth(ctx, &auth.RequestAuth{AccountID: acc.Identifier(), Account: acc}), acc)
if err != nil {
results = append(results, &SessionStats{
AccountID: accountID,
@@ -171,7 +173,8 @@ func (c *Client) GetSessionCountAll(ctx context.Context) []*SessionStats {
}
}
stats, err := c.GetSessionCountForToken(ctx, token)
ctxWithAuth := auth.WithAuth(ctx, &auth.RequestAuth{AccountID: acc.Identifier(), Account: acc, DeepSeekToken: token})
stats, err := c.GetSessionCountForToken(ctxWithAuth, token)
if err != nil {
results = append(results, &SessionStats{
AccountID: accountID,
@@ -190,6 +193,7 @@ func (c *Client) GetSessionCountAll(ctx context.Context) []*SessionStats {
// FetchSessionPage 获取会话列表(支持分页)
func (c *Client) FetchSessionPage(ctx context.Context, a *auth.RequestAuth, cursor string) ([]SessionInfo, bool, error) {
clients := c.requestClientsForAuth(ctx, a)
headers := c.authHeaders(a.DeepSeekToken)
// 构建请求 URL
@@ -200,7 +204,7 @@ func (c *Client) FetchSessionPage(ctx context.Context, a *auth.RequestAuth, curs
}
reqURL := DeepSeekFetchSessionURL + "?" + params.Encode()
resp, status, err := c.getJSONWithStatus(ctx, c.regular, reqURL, headers)
resp, status, err := c.getJSONWithStatus(ctx, clients.regular, reqURL, headers)
if err != nil {
return nil, false, err
}

View File

@@ -22,6 +22,7 @@ func (c *Client) DeleteSession(ctx context.Context, a *auth.RequestAuth, session
if maxAttempts <= 0 {
maxAttempts = c.maxRetries
}
clients := c.requestClientsForAuth(ctx, a)
result := &DeleteSessionResult{
SessionID: sessionID,
@@ -42,7 +43,7 @@ func (c *Client) DeleteSession(ctx context.Context, a *auth.RequestAuth, session
"chat_session_id": sessionID,
}
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekDeleteSessionURL, headers, payload)
if err != nil {
config.Logger.Warn("[delete_session] request error", "error", err, "session_id", sessionID)
attempts++
@@ -81,6 +82,7 @@ func (c *Client) DeleteSession(ctx context.Context, a *auth.RequestAuth, session
// DeleteSessionForToken 直接使用 token 删除会话(直通模式)
func (c *Client) DeleteSessionForToken(ctx context.Context, token string, sessionID string) (*DeleteSessionResult, error) {
clients := c.requestClientsFromContext(ctx)
result := &DeleteSessionResult{
SessionID: sessionID,
}
@@ -95,7 +97,7 @@ func (c *Client) DeleteSessionForToken(ctx context.Context, token string, sessio
"chat_session_id": sessionID,
}
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteSessionURL, headers, payload)
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekDeleteSessionURL, headers, payload)
if err != nil {
result.ErrorMessage = err.Error()
return result, err
@@ -114,10 +116,11 @@ func (c *Client) DeleteSessionForToken(ctx context.Context, token string, sessio
// DeleteAllSessions 删除所有会话(谨慎使用)
func (c *Client) DeleteAllSessions(ctx context.Context, a *auth.RequestAuth) error {
clients := c.requestClientsForAuth(ctx, a)
headers := c.authHeaders(a.DeepSeekToken)
payload := map[string]any{}
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
if err != nil {
config.Logger.Warn("[delete_all_sessions] request error", "error", err)
return err
@@ -135,10 +138,11 @@ func (c *Client) DeleteAllSessions(ctx context.Context, a *auth.RequestAuth) err
// DeleteAllSessionsForToken 直接使用 token 删除所有会话(直通模式)
func (c *Client) DeleteAllSessionsForToken(ctx context.Context, token string) error {
clients := c.requestClientsFromContext(ctx)
headers := c.authHeaders(token)
payload := map[string]any{}
resp, status, err := c.postJSONWithStatus(ctx, c.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
resp, status, err := c.postJSONWithStatus(ctx, clients.regular, DeepSeekDeleteAllSessionsURL, headers, payload)
if err != nil {
config.Logger.Warn("[delete_all_sessions_for_token] request error", "error", err)
return err

239
internal/deepseek/proxy.go Normal file
View File

@@ -0,0 +1,239 @@
package deepseek
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
"strings"
"time"
"golang.org/x/net/proxy"
"ds2api/internal/auth"
"ds2api/internal/config"
trans "ds2api/internal/deepseek/transport"
)
type requestClients struct {
regular trans.Doer
stream trans.Doer
fallback *http.Client
fallbackS *http.Client
}
type hostLookupFunc func(ctx context.Context, network, host string) ([]string, error)
var proxyConnectivityTestURL = "https://chat.deepseek.com/"
var defaultHostLookup hostLookupFunc = func(ctx context.Context, _ string, host string) ([]string, error) {
return net.DefaultResolver.LookupHost(ctx, host)
}
func proxyDialAddress(ctx context.Context, proxyType, address string, lookup hostLookupFunc) (string, error) {
proxyType = strings.ToLower(strings.TrimSpace(proxyType))
if proxyType != "socks5" {
return address, nil
}
host, port, err := net.SplitHostPort(address)
if err != nil {
return "", err
}
if net.ParseIP(host) != nil {
return address, nil
}
if lookup == nil {
lookup = defaultHostLookup
}
addrs, err := lookup(ctx, "ip", host)
if err != nil {
return "", err
}
if len(addrs) == 0 {
return "", fmt.Errorf("no ip address resolved for %s", host)
}
return net.JoinHostPort(addrs[0], port), nil
}
func proxyCacheKey(proxyCfg config.Proxy) string {
proxyCfg = config.NormalizeProxy(proxyCfg)
return strings.Join([]string{
proxyCfg.ID,
proxyCfg.Type,
strings.ToLower(proxyCfg.Host),
strconv.Itoa(proxyCfg.Port),
proxyCfg.Username,
proxyCfg.Password,
}, "|")
}
func proxyDialContext(proxyCfg config.Proxy) (trans.DialContextFunc, error) {
proxyCfg = config.NormalizeProxy(proxyCfg)
var authCfg *proxy.Auth
if proxyCfg.Username != "" || proxyCfg.Password != "" {
authCfg = &proxy.Auth{User: proxyCfg.Username, Password: proxyCfg.Password}
}
forward := &net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}
dialer, err := proxy.SOCKS5("tcp", net.JoinHostPort(proxyCfg.Host, strconv.Itoa(proxyCfg.Port)), authCfg, forward)
if err != nil {
return nil, err
}
return func(ctx context.Context, network, address string) (net.Conn, error) {
target, err := proxyDialAddress(ctx, proxyCfg.Type, address, defaultHostLookup)
if err != nil {
return nil, err
}
if ctxDialer, ok := dialer.(proxy.ContextDialer); ok {
return ctxDialer.DialContext(ctx, network, target)
}
return dialer.Dial(network, target)
}, nil
}
func (c *Client) defaultRequestClients() requestClients {
return requestClients{
regular: c.regular,
stream: c.stream,
fallback: c.fallback,
fallbackS: c.fallbackS,
}
}
func (c *Client) resolveProxyForAccount(acc config.Account) (config.Proxy, bool) {
if c == nil || c.Store == nil {
return config.Proxy{}, false
}
proxyID := strings.TrimSpace(acc.ProxyID)
if proxyID == "" {
return config.Proxy{}, false
}
snap := c.Store.Snapshot()
for _, proxyCfg := range snap.Proxies {
proxyCfg = config.NormalizeProxy(proxyCfg)
if proxyCfg.ID == proxyID {
return proxyCfg, true
}
}
return config.Proxy{}, false
}
func (c *Client) requestClientsFromContext(ctx context.Context) requestClients {
if a, ok := auth.FromContext(ctx); ok {
return c.requestClientsForAccount(a.Account)
}
return c.defaultRequestClients()
}
func (c *Client) requestClientsForAuth(ctx context.Context, a *auth.RequestAuth) requestClients {
if a != nil {
return c.requestClientsForAccount(a.Account)
}
return c.requestClientsFromContext(ctx)
}
func (c *Client) requestClientsForAccount(acc config.Account) requestClients {
proxyCfg, ok := c.resolveProxyForAccount(acc)
if !ok {
return c.defaultRequestClients()
}
key := proxyCacheKey(proxyCfg)
c.proxyClientsMu.RLock()
cached, ok := c.proxyClients[key]
c.proxyClientsMu.RUnlock()
if ok {
return cached
}
dialContext, err := proxyDialContext(proxyCfg)
if err != nil {
config.Logger.Warn("[proxy] build dialer failed", "proxy_id", proxyCfg.ID, "error", err)
return c.defaultRequestClients()
}
bundle := requestClients{
regular: trans.NewWithDialContext(60*time.Second, dialContext),
stream: trans.NewWithDialContext(0, dialContext),
fallback: trans.NewFallbackClient(60*time.Second, dialContext),
fallbackS: trans.NewFallbackClient(0, dialContext),
}
c.proxyClientsMu.Lock()
if c.proxyClients == nil {
c.proxyClients = make(map[string]requestClients)
}
c.proxyClients[key] = bundle
c.proxyClientsMu.Unlock()
return bundle
}
func applyProxyConnectivityHeaders(req *http.Request) {
if req == nil {
return
}
for key, value := range BaseHeaders {
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" || value == "" {
continue
}
req.Header.Set(key, value)
}
}
func proxyConnectivityStatus(statusCode int) (bool, string) {
switch {
case statusCode >= 200 && statusCode < 300:
return true, fmt.Sprintf("代理可达,目标返回 HTTP %d", statusCode)
case statusCode >= 300 && statusCode < 500:
return true, fmt.Sprintf("代理可达,但目标返回 HTTP %d可能是风控或挑战", statusCode)
default:
return false, fmt.Sprintf("目标返回 HTTP %d", statusCode)
}
}
func TestProxyConnectivity(ctx context.Context, proxyCfg config.Proxy) map[string]any {
start := time.Now()
proxyCfg = config.NormalizeProxy(proxyCfg)
result := map[string]any{
"success": false,
"proxy_id": proxyCfg.ID,
"proxy_type": proxyCfg.Type,
"response_time": 0,
}
if err := config.ValidateProxyConfig([]config.Proxy{proxyCfg}); err != nil {
result["message"] = "代理配置无效: " + err.Error()
return result
}
dialContext, err := proxyDialContext(proxyCfg)
if err != nil {
result["message"] = "代理拨号器初始化失败: " + err.Error()
return result
}
client := trans.NewFallbackClient(15*time.Second, dialContext)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, proxyConnectivityTestURL, nil)
if err != nil {
result["message"] = err.Error()
return result
}
applyProxyConnectivityHeaders(req)
resp, err := client.Do(req)
result["response_time"] = int(time.Since(start).Milliseconds())
if err != nil {
result["message"] = err.Error()
return result
}
defer func() {
if closeErr := resp.Body.Close(); closeErr != nil {
config.Logger.Warn("[proxy] close response body failed", "proxy_id", proxyCfg.ID, "error", closeErr)
}
}()
result["status_code"] = resp.StatusCode
result["success"], result["message"] = proxyConnectivityStatus(resp.StatusCode)
return result
}

View File

@@ -0,0 +1,85 @@
package deepseek
import (
"context"
"net/http"
"strings"
"testing"
)
func TestProxyDialAddressUsesLocalResolutionForSocks5(t *testing.T) {
ctx := context.Background()
resolved, err := proxyDialAddress(ctx, "socks5", "example.com:443", func(_ context.Context, network, host string) ([]string, error) {
if network != "ip" {
t.Fatalf("unexpected lookup network: %q", network)
}
if host != "example.com" {
t.Fatalf("unexpected lookup host: %q", host)
}
return []string{"203.0.113.10"}, nil
})
if err != nil {
t.Fatalf("proxyDialAddress returned error: %v", err)
}
if resolved != "203.0.113.10:443" {
t.Fatalf("expected locally resolved address, got %q", resolved)
}
}
func TestProxyDialAddressKeepsHostnameForSocks5h(t *testing.T) {
ctx := context.Background()
lookups := 0
resolved, err := proxyDialAddress(ctx, "socks5h", "example.com:443", func(_ context.Context, network, host string) ([]string, error) {
lookups++
return []string{"203.0.113.10"}, nil
})
if err != nil {
t.Fatalf("proxyDialAddress returned error: %v", err)
}
if resolved != "example.com:443" {
t.Fatalf("expected hostname preserved for remote DNS, got %q", resolved)
}
if lookups != 0 {
t.Fatalf("expected no local DNS lookup for socks5h, got %d", lookups)
}
}
func TestApplyProxyConnectivityHeadersUsesBaseHeaders(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "https://chat.deepseek.com/", nil)
if err != nil {
t.Fatalf("http.NewRequest returned error: %v", err)
}
applyProxyConnectivityHeaders(req)
for key, want := range BaseHeaders {
if got := req.Header.Get(key); got != want {
t.Fatalf("expected header %q=%q, got %q", key, want, got)
}
}
}
func TestProxyConnectivityStatus(t *testing.T) {
cases := []struct {
name string
statusCode int
success bool
wantText string
}{
{name: "ok", statusCode: 200, success: true, wantText: "HTTP 200"},
{name: "challenge", statusCode: 403, success: true, wantText: "风控或挑战"},
{name: "upstream error", statusCode: 502, success: false, wantText: "HTTP 502"},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
success, message := proxyConnectivityStatus(tc.statusCode)
if success != tc.success {
t.Fatalf("expected success=%v, got %v", tc.success, success)
}
if message == "" || !strings.Contains(message, tc.wantText) {
t.Fatalf("expected message to contain %q, got %q", tc.wantText, message)
}
})
}
}

View File

@@ -15,21 +15,33 @@ type Doer interface {
Do(req *http.Request) (*http.Response, error)
}
type DialContextFunc func(ctx context.Context, network, addr string) (net.Conn, error)
type Client struct {
http *http.Client
}
func New(timeout time.Duration) *Client {
return NewWithDialContext(timeout, nil)
}
func NewWithDialContext(timeout time.Duration, dialContext DialContextFunc) *Client {
useEnvProxy := dialContext == nil
if dialContext == nil {
dialContext = (&net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}).DialContext
}
base := &http.Transport{
Proxy: http.ProxyFromEnvironment,
ForceAttemptHTTP2: false,
MaxIdleConns: 200,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
DialContext: (&net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}).DialContext,
DialTLSContext: safariTLSDialer(),
DialContext: dialContext,
DialTLSContext: safariTLSDialer(dialContext),
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
}
if useEnvProxy {
base.Proxy = http.ProxyFromEnvironment
}
return &Client{http: &http.Client{Timeout: timeout, Transport: base}}
}
@@ -37,10 +49,31 @@ func (c *Client) Do(req *http.Request) (*http.Response, error) {
return c.http.Do(req)
}
func safariTLSDialer() func(ctx context.Context, network, addr string) (net.Conn, error) {
var dialer net.Dialer
func NewFallbackClient(timeout time.Duration, dialContext DialContextFunc) *http.Client {
useEnvProxy := dialContext == nil
if dialContext == nil {
dialContext = (&net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}).DialContext
}
base := &http.Transport{
ForceAttemptHTTP2: false,
MaxIdleConns: 200,
MaxIdleConnsPerHost: 100,
IdleConnTimeout: 90 * time.Second,
DialContext: dialContext,
TLSClientConfig: &tls.Config{MinVersion: tls.VersionTLS12},
}
if useEnvProxy {
base.Proxy = http.ProxyFromEnvironment
}
return &http.Client{Timeout: timeout, Transport: base}
}
func safariTLSDialer(dialContext DialContextFunc) func(ctx context.Context, network, addr string) (net.Conn, error) {
if dialContext == nil {
dialContext = (&net.Dialer{Timeout: 15 * time.Second, KeepAlive: 30 * time.Second}).DialContext
}
return func(ctx context.Context, network, addr string) (net.Conn, error) {
plainConn, err := dialer.DialContext(ctx, network, addr)
plainConn, err := dialContext(ctx, network, addr)
if err != nil {
return nil, err
}