mirror of
https://github.com/CJackHwang/ds2api.git
synced 2026-05-06 17:35:30 +08:00
fix: harden webui path and account pool compatibility
This commit is contained in:
@@ -152,3 +152,28 @@ func TestPoolAccountConcurrencyAliasEnv(t *testing.T) {
|
||||
t.Fatalf("unexpected recommended_concurrency: %#v", status["recommended_concurrency"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoolSupportsTokenOnlyAccount(t *testing.T) {
|
||||
t.Setenv("DS2API_ACCOUNT_MAX_INFLIGHT", "1")
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"token":"token-only-account"}]
|
||||
}`)
|
||||
|
||||
pool := NewPool(config.LoadStore())
|
||||
status := pool.Status()
|
||||
if got, ok := status["total"].(int); !ok || got != 1 {
|
||||
t.Fatalf("unexpected total in pool status: %#v", status["total"])
|
||||
}
|
||||
if got, ok := status["available"].(int); !ok || got != 1 {
|
||||
t.Fatalf("unexpected available in pool status: %#v", status["available"])
|
||||
}
|
||||
|
||||
acc, ok := pool.Acquire("", nil)
|
||||
if !ok {
|
||||
t.Fatalf("expected acquire success for token-only account")
|
||||
}
|
||||
if acc.Token != "token-only-account" {
|
||||
t.Fatalf("unexpected token on acquired account: %q", acc.Token)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -226,6 +226,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
hasContent := false
|
||||
keepaliveTicker := time.NewTicker(time.Duration(deepseek.KeepAliveTimeout) * time.Second)
|
||||
defer keepaliveTicker.Stop()
|
||||
keepaliveCountWithoutContent := 0
|
||||
|
||||
sendChunk := func(v any) {
|
||||
b, _ := json.Marshal(v)
|
||||
@@ -301,6 +302,13 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
case <-r.Context().Done():
|
||||
return
|
||||
case <-keepaliveTicker.C:
|
||||
if !hasContent {
|
||||
keepaliveCountWithoutContent++
|
||||
if keepaliveCountWithoutContent >= deepseek.MaxKeepaliveCount {
|
||||
finalize("stop")
|
||||
return
|
||||
}
|
||||
}
|
||||
if hasContent && time.Since(lastContent) > time.Duration(deepseek.StreamIdleTimeout)*time.Second {
|
||||
finalize("stop")
|
||||
return
|
||||
@@ -343,6 +351,7 @@ func (h *Handler) handleStream(w http.ResponseWriter, r *http.Request, resp *htt
|
||||
}
|
||||
hasContent = true
|
||||
lastContent = time.Now()
|
||||
keepaliveCountWithoutContent = 0
|
||||
delta := map[string]any{}
|
||||
if !firstChunkSent {
|
||||
delta["role"] = "assistant"
|
||||
|
||||
@@ -873,7 +873,20 @@ func toStringSlice(v any) ([]string, bool) {
|
||||
}
|
||||
|
||||
func toAccount(m map[string]any) config.Account {
|
||||
return config.Account{Email: strings.TrimSpace(fmt.Sprintf("%v", m["email"])), Mobile: strings.TrimSpace(fmt.Sprintf("%v", m["mobile"])), Password: strings.TrimSpace(fmt.Sprintf("%v", m["password"])), Token: strings.TrimSpace(fmt.Sprintf("%v", m["token"]))}
|
||||
return config.Account{
|
||||
Email: fieldString(m, "email"),
|
||||
Mobile: fieldString(m, "mobile"),
|
||||
Password: fieldString(m, "password"),
|
||||
Token: fieldString(m, "token"),
|
||||
}
|
||||
}
|
||||
|
||||
func fieldString(m map[string]any, key string) string {
|
||||
v, ok := m[key]
|
||||
if !ok || v == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprintf("%v", v))
|
||||
}
|
||||
|
||||
func statusOr(v int, d int) int {
|
||||
|
||||
28
internal/admin/handler_test.go
Normal file
28
internal/admin/handler_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package admin
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestToAccountMissingFieldsRemainEmpty(t *testing.T) {
|
||||
acc := toAccount(map[string]any{
|
||||
"email": "user@example.com",
|
||||
"password": "secret",
|
||||
})
|
||||
if acc.Email != "user@example.com" {
|
||||
t.Fatalf("unexpected email: %q", acc.Email)
|
||||
}
|
||||
if acc.Mobile != "" {
|
||||
t.Fatalf("expected empty mobile, got %q", acc.Mobile)
|
||||
}
|
||||
if acc.Token != "" {
|
||||
t.Fatalf("expected empty token, got %q", acc.Token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldStringNilToEmpty(t *testing.T) {
|
||||
if got := fieldString(map[string]any{"token": nil}, "token"); got != "" {
|
||||
t.Fatalf("expected empty string for nil field, got %q", got)
|
||||
}
|
||||
if got := fieldString(map[string]any{}, "token"); got != "" {
|
||||
t.Fatalf("expected empty string for missing field, got %q", got)
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"log/slog"
|
||||
@@ -41,7 +43,17 @@ func (a Account) Identifier() string {
|
||||
if strings.TrimSpace(a.Email) != "" {
|
||||
return strings.TrimSpace(a.Email)
|
||||
}
|
||||
return strings.TrimSpace(a.Mobile)
|
||||
if strings.TrimSpace(a.Mobile) != "" {
|
||||
return strings.TrimSpace(a.Mobile)
|
||||
}
|
||||
// Backward compatibility: old configs may contain token-only accounts.
|
||||
// Use a stable non-sensitive synthetic id so they can still join the pool.
|
||||
token := strings.TrimSpace(a.Token)
|
||||
if token == "" {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return "token:" + hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
|
||||
41
internal/config/config_test.go
Normal file
41
internal/config/config_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAccountIdentifierFallsBackToTokenHash(t *testing.T) {
|
||||
acc := Account{Token: "example-token-value"}
|
||||
id := acc.Identifier()
|
||||
if !strings.HasPrefix(id, "token:") {
|
||||
t.Fatalf("expected token-prefixed identifier, got %q", id)
|
||||
}
|
||||
if len(id) != len("token:")+16 {
|
||||
t.Fatalf("unexpected identifier length: %d (%q)", len(id), id)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStoreFindAccountWithTokenOnlyIdentifier(t *testing.T) {
|
||||
t.Setenv("DS2API_CONFIG_JSON", `{
|
||||
"keys":["k1"],
|
||||
"accounts":[{"token":"token-only-account"}]
|
||||
}`)
|
||||
|
||||
store := LoadStore()
|
||||
accounts := store.Accounts()
|
||||
if len(accounts) != 1 {
|
||||
t.Fatalf("expected 1 account, got %d", len(accounts))
|
||||
}
|
||||
id := accounts[0].Identifier()
|
||||
if id == "" {
|
||||
t.Fatalf("expected synthetic identifier for token-only account")
|
||||
}
|
||||
found, ok := store.FindAccount(id)
|
||||
if !ok {
|
||||
t.Fatalf("expected FindAccount to locate token-only account by synthetic id")
|
||||
}
|
||||
if found.Token != "token-only-account" {
|
||||
t.Fatalf("unexpected token value: %q", found.Token)
|
||||
}
|
||||
}
|
||||
@@ -21,7 +21,7 @@ type Handler struct {
|
||||
}
|
||||
|
||||
func NewHandler() *Handler {
|
||||
return &Handler{StaticDir: config.StaticAdminDir()}
|
||||
return &Handler{StaticDir: resolveStaticAdminDir(config.StaticAdminDir())}
|
||||
}
|
||||
|
||||
func RegisterRoutes(r chi.Router, h *Handler) {
|
||||
@@ -47,19 +47,20 @@ func (h *Handler) index(w http.ResponseWriter, _ *http.Request) {
|
||||
}
|
||||
|
||||
func (h *Handler) admin(w http.ResponseWriter, r *http.Request) {
|
||||
if fi, err := os.Stat(h.StaticDir); err == nil && fi.IsDir() {
|
||||
h.serveFromDisk(w, r)
|
||||
staticDir := resolveStaticAdminDir(h.StaticDir)
|
||||
if fi, err := os.Stat(staticDir); err == nil && fi.IsDir() {
|
||||
h.serveFromDisk(w, r, staticDir)
|
||||
return
|
||||
}
|
||||
http.Error(w, "WebUI not built. Run `cd webui && npm run build` first.", http.StatusNotFound)
|
||||
}
|
||||
|
||||
func (h *Handler) serveFromDisk(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *Handler) serveFromDisk(w http.ResponseWriter, r *http.Request, staticDir string) {
|
||||
path := strings.TrimPrefix(r.URL.Path, "/admin")
|
||||
path = strings.TrimPrefix(path, "/")
|
||||
if path != "" && strings.Contains(path, ".") {
|
||||
full := filepath.Join(h.StaticDir, filepath.Clean(path))
|
||||
if !strings.HasPrefix(full, h.StaticDir) {
|
||||
full := filepath.Join(staticDir, filepath.Clean(path))
|
||||
if !strings.HasPrefix(full, staticDir) {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
@@ -75,7 +76,7 @@ func (h *Handler) serveFromDisk(w http.ResponseWriter, r *http.Request) {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
index := filepath.Join(h.StaticDir, "index.html")
|
||||
index := filepath.Join(staticDir, "index.html")
|
||||
if _, err := os.Stat(index); err != nil {
|
||||
http.Error(w, "index.html not found", http.StatusNotFound)
|
||||
return
|
||||
@@ -83,3 +84,35 @@ func (h *Handler) serveFromDisk(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Cache-Control", "no-store, must-revalidate")
|
||||
http.ServeFile(w, r, index)
|
||||
}
|
||||
|
||||
func resolveStaticAdminDir(preferred string) string {
|
||||
candidates := []string{preferred}
|
||||
if wd, err := os.Getwd(); err == nil {
|
||||
candidates = append(candidates, filepath.Join(wd, "static/admin"))
|
||||
}
|
||||
if exe, err := os.Executable(); err == nil {
|
||||
exeDir := filepath.Dir(exe)
|
||||
candidates = append(candidates,
|
||||
filepath.Join(exeDir, "static/admin"),
|
||||
filepath.Join(filepath.Dir(exeDir), "static/admin"),
|
||||
)
|
||||
}
|
||||
// Common serverless locations.
|
||||
candidates = append(candidates, "/var/task/static/admin", "/var/task/user/static/admin")
|
||||
|
||||
seen := map[string]struct{}{}
|
||||
for _, c := range candidates {
|
||||
c = filepath.Clean(strings.TrimSpace(c))
|
||||
if c == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[c]; ok {
|
||||
continue
|
||||
}
|
||||
seen[c] = struct{}{}
|
||||
if fi, err := os.Stat(c); err == nil && fi.IsDir() {
|
||||
return c
|
||||
}
|
||||
}
|
||||
return filepath.Clean(preferred)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user