feat: implement comprehensive configuration validation and integrate into store loading and server initialization.

This commit is contained in:
CJACK
2026-04-05 21:18:51 +08:00
parent 585d35e592
commit a28c9fb67f
11 changed files with 299 additions and 60 deletions

View File

@@ -3,9 +3,17 @@ package app
import (
"net/http"
"ds2api/internal/config"
"ds2api/internal/server"
)
func NewHandler() http.Handler {
return server.NewApp().Router
app, err := server.NewApp()
if err != nil {
config.Logger.Error("[app] init failed", "error", err)
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
server.WriteUnhandledError(w, err)
})
}
return app.Router
}

View File

@@ -24,7 +24,11 @@ func main() {
config.RefreshLogger()
webui.EnsureBuiltOnStartup()
_ = auth.AdminKey()
app := server.NewApp()
app, err := server.NewApp()
if err != nil {
config.Logger.Error("server initialization failed", "error", err)
os.Exit(1)
}
port := strings.TrimSpace(os.Getenv("PORT"))
if port == "" {
port = "5001"

View File

@@ -37,8 +37,8 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
cfg := &config.AdminConfig{}
if v, exists := raw["jwt_expire_hours"]; exists {
n := intFrom(v)
if n < 1 || n > 720 {
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
if err := config.ValidateIntRange("admin.jwt_expire_hours", n, 1, 720, true); err != nil {
return nil, nil, nil, nil, nil, nil, nil, nil, err
}
cfg.JWTExpireHours = n
}
@@ -49,29 +49,29 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
cfg := &config.RuntimeConfig{}
if v, exists := raw["account_max_inflight"]; exists {
n := intFrom(v)
if n < 1 || n > 256 {
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
if err := config.ValidateIntRange("runtime.account_max_inflight", n, 1, 256, true); err != nil {
return nil, nil, nil, nil, nil, nil, nil, nil, err
}
cfg.AccountMaxInflight = n
}
if v, exists := raw["account_max_queue"]; exists {
n := intFrom(v)
if n < 1 || n > 200000 {
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
if err := config.ValidateIntRange("runtime.account_max_queue", n, 1, 200000, true); err != nil {
return nil, nil, nil, nil, nil, nil, nil, nil, err
}
cfg.AccountMaxQueue = n
}
if v, exists := raw["global_max_inflight"]; exists {
n := intFrom(v)
if n < 1 || n > 200000 {
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
if err := config.ValidateIntRange("runtime.global_max_inflight", n, 1, 200000, true); err != nil {
return nil, nil, nil, nil, nil, nil, nil, nil, err
}
cfg.GlobalMaxInflight = n
}
if v, exists := raw["token_refresh_interval_hours"]; exists {
n := intFrom(v)
if n < 1 || n > 720 {
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.token_refresh_interval_hours must be between 1 and 720")
if err := config.ValidateIntRange("runtime.token_refresh_interval_hours", n, 1, 720, true); err != nil {
return nil, nil, nil, nil, nil, nil, nil, nil, err
}
cfg.TokenRefreshIntervalHours = n
}
@@ -98,8 +98,8 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
cfg := &config.ResponsesConfig{}
if v, exists := raw["store_ttl_seconds"]; exists {
n := intFrom(v)
if n < 30 || n > 86400 {
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
if err := config.ValidateIntRange("responses.store_ttl_seconds", n, 30, 86400, true); err != nil {
return nil, nil, nil, nil, nil, nil, nil, nil, err
}
cfg.StoreTTLSeconds = n
}
@@ -110,6 +110,9 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
cfg := &config.EmbeddingsConfig{}
if v, exists := raw["provider"]; exists {
p := strings.TrimSpace(fmt.Sprintf("%v", v))
if err := config.ValidateTrimmedString("embeddings.provider", p, false); err != nil {
return nil, nil, nil, nil, nil, nil, nil, nil, err
}
cfg.Provider = p
}
embCfg = cfg
@@ -143,14 +146,13 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi
cfg := &config.AutoDeleteConfig{}
if v, exists := raw["mode"]; exists {
mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v)))
switch mode {
case "", "none":
cfg.Mode = "none"
case "single", "all":
cfg.Mode = mode
default:
return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("auto_delete.mode must be one of none, single, all")
if err := config.ValidateAutoDeleteMode(mode); err != nil {
return nil, nil, nil, nil, nil, nil, nil, nil, err
}
if mode == "" {
mode = "none"
}
cfg.Mode = mode
}
if v, exists := raw["sessions"]; exists {
cfg.Sessions = boolFrom(v)

View File

@@ -82,6 +82,28 @@ func TestUpdateSettingsValidationRejectsTokenRefreshInterval(t *testing.T) {
}
}
func TestUpdateSettingsAllowsEmptyEmbeddingsProvider(t *testing.T) {
h := newAdminTestHandler(t, `{"keys":["k1"]}`)
payload := map[string]any{
"responses": map[string]any{
"store_ttl_seconds": 600,
},
"embeddings": map[string]any{
"provider": "",
},
}
b, _ := json.Marshal(payload)
req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b))
rec := httptest.NewRecorder()
h.updateSettings(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String())
}
if got := h.Store.Snapshot().Responses.StoreTTLSeconds; got != 600 {
t.Fatalf("store_ttl_seconds=%d want=600", got)
}
}
func TestUpdateSettingsValidationWithMergedRuntimeSnapshot(t *testing.T) {
h := newAdminTestHandler(t, `{
"keys":["k1"],

View File

@@ -1,7 +1,6 @@
package admin
import (
"fmt"
"strings"
"ds2api/internal/config"
@@ -16,36 +15,9 @@ func normalizeSettingsConfig(c *config.Config) {
}
func validateSettingsConfig(c config.Config) error {
if c.Admin.JWTExpireHours != 0 && (c.Admin.JWTExpireHours < 1 || c.Admin.JWTExpireHours > 720) {
return fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720")
}
if err := validateRuntimeSettings(c.Runtime); err != nil {
return err
}
if c.Responses.StoreTTLSeconds != 0 && (c.Responses.StoreTTLSeconds < 30 || c.Responses.StoreTTLSeconds > 86400) {
return fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400")
}
if c.Embeddings.Provider != "" && strings.TrimSpace(c.Embeddings.Provider) == "" {
return fmt.Errorf("embeddings.provider cannot be empty")
}
return nil
return config.ValidateConfig(c)
}
func validateRuntimeSettings(runtime config.RuntimeConfig) error {
if runtime.AccountMaxInflight != 0 && (runtime.AccountMaxInflight < 1 || runtime.AccountMaxInflight > 256) {
return fmt.Errorf("runtime.account_max_inflight must be between 1 and 256")
}
if runtime.AccountMaxQueue != 0 && (runtime.AccountMaxQueue < 1 || runtime.AccountMaxQueue > 200000) {
return fmt.Errorf("runtime.account_max_queue must be between 1 and 200000")
}
if runtime.GlobalMaxInflight != 0 && (runtime.GlobalMaxInflight < 1 || runtime.GlobalMaxInflight > 200000) {
return fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000")
}
if runtime.TokenRefreshIntervalHours != 0 && (runtime.TokenRefreshIntervalHours < 1 || runtime.TokenRefreshIntervalHours > 720) {
return fmt.Errorf("runtime.token_refresh_interval_hours must be between 1 and 720")
}
if runtime.AccountMaxInflight > 0 && runtime.GlobalMaxInflight > 0 && runtime.GlobalMaxInflight < runtime.AccountMaxInflight {
return fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
}
return nil
return config.ValidateRuntimeConfig(runtime)
}

View File

@@ -176,6 +176,56 @@ func TestEnvBackedStoreWritebackDoesNotBootstrapOnInvalidEnvJSON(t *testing.T) {
}
}
func TestEnvBackedStoreWritebackDoesNotBootstrapOnInvalidSemanticConfig(t *testing.T) {
tmp, err := os.CreateTemp(t.TempDir(), "config-*.json")
if err != nil {
t.Fatalf("create temp config: %v", err)
}
path := tmp.Name()
_ = tmp.Close()
_ = os.Remove(path)
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["k1"],
"accounts":[{"email":"seed@example.com","password":"p"}],
"runtime":{"account_max_inflight":300}
}`)
t.Setenv("DS2API_CONFIG_PATH", path)
t.Setenv("DS2API_ENV_WRITEBACK", "1")
cfg, fromEnv, loadErr := loadConfig()
if loadErr == nil {
t.Fatalf("expected loadConfig error for invalid runtime config")
}
if !fromEnv {
t.Fatalf("expected fromEnv=true when env config is the source")
}
if !strings.Contains(loadErr.Error(), "runtime.account_max_inflight") {
t.Fatalf("expected runtime validation error, got %v", loadErr)
}
if len(cfg.Keys) != 1 || len(cfg.Accounts) != 1 {
t.Fatalf("expected env config to be parsed before validation, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts))
}
if _, statErr := os.Stat(path); !errors.Is(statErr, os.ErrNotExist) {
t.Fatalf("expected invalid config not to be bootstrapped, stat err=%v", statErr)
}
}
func TestLoadStoreWithErrorRejectsInvalidRuntimeConfig(t *testing.T) {
t.Setenv("DS2API_CONFIG_JSON", `{
"keys":["k1"],
"accounts":[{"email":"u@example.com","password":"p"}],
"runtime":{"account_max_inflight":300}
}`)
t.Setenv("DS2API_ENV_WRITEBACK", "0")
if _, err := LoadStoreWithError(); err == nil {
t.Fatal("expected LoadStoreWithError to reject invalid runtime config")
} else if !strings.Contains(err.Error(), "runtime.account_max_inflight") {
t.Fatalf("expected runtime validation error, got %v", err)
}
}
func TestEnvBackedStoreWritebackFallsBackToPersistedFileOnInvalidEnvJSON(t *testing.T) {
tmp, err := os.CreateTemp(t.TempDir(), "config-*.json")
if err != nil {

View File

@@ -21,16 +21,32 @@ type Store struct {
}
func LoadStore() *Store {
cfg, fromEnv, err := loadConfig()
store, err := loadStore()
if err != nil {
Logger.Warn("[config] load failed", "error", err)
}
if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 {
if len(store.cfg.Keys) == 0 && len(store.cfg.Accounts) == 0 {
Logger.Warn("[config] empty config loaded")
}
s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}
s.rebuildIndexes()
return s
store.rebuildIndexes()
return store
}
func LoadStoreWithError() (*Store, error) {
store, err := loadStore()
if err != nil {
return nil, err
}
store.rebuildIndexes()
return store, nil
}
func loadStore() (*Store, error) {
cfg, fromEnv, err := loadConfig()
if validateErr := ValidateConfig(cfg); validateErr != nil {
err = errors.Join(err, validateErr)
}
return &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}, err
}
func loadConfig() (Config, bool, error) {
@@ -59,6 +75,9 @@ func loadConfig() (Config, bool, error) {
}
}
if errors.Is(fileErr, os.ErrNotExist) {
if validateErr := ValidateConfig(cfg); validateErr != nil {
return cfg, true, validateErr
}
if writeErr := writeConfigFile(ConfigPath(), cfg.Clone()); writeErr == nil {
return cfg, false, err
} else {

View File

@@ -0,0 +1,91 @@
package config
import (
"fmt"
"strings"
)
func ValidateConfig(c Config) error {
if err := ValidateAdminConfig(c.Admin); err != nil {
return err
}
if err := ValidateRuntimeConfig(c.Runtime); err != nil {
return err
}
if err := ValidateResponsesConfig(c.Responses); err != nil {
return err
}
if err := ValidateEmbeddingsConfig(c.Embeddings); err != nil {
return err
}
if err := ValidateAutoDeleteConfig(c.AutoDelete); err != nil {
return err
}
return nil
}
func ValidateAdminConfig(admin AdminConfig) error {
return ValidateIntRange("admin.jwt_expire_hours", admin.JWTExpireHours, 1, 720, false)
}
func ValidateRuntimeConfig(runtime RuntimeConfig) error {
if err := ValidateIntRange("runtime.account_max_inflight", runtime.AccountMaxInflight, 1, 256, false); err != nil {
return err
}
if err := ValidateIntRange("runtime.account_max_queue", runtime.AccountMaxQueue, 1, 200000, false); err != nil {
return err
}
if err := ValidateIntRange("runtime.global_max_inflight", runtime.GlobalMaxInflight, 1, 200000, false); err != nil {
return err
}
if err := ValidateIntRange("runtime.token_refresh_interval_hours", runtime.TokenRefreshIntervalHours, 1, 720, false); err != nil {
return err
}
if runtime.AccountMaxInflight > 0 && runtime.GlobalMaxInflight > 0 && runtime.GlobalMaxInflight < runtime.AccountMaxInflight {
return fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight")
}
return nil
}
func ValidateResponsesConfig(responses ResponsesConfig) error {
return ValidateIntRange("responses.store_ttl_seconds", responses.StoreTTLSeconds, 30, 86400, false)
}
func ValidateEmbeddingsConfig(embeddings EmbeddingsConfig) error {
return ValidateTrimmedString("embeddings.provider", embeddings.Provider, false)
}
func ValidateAutoDeleteConfig(autoDelete AutoDeleteConfig) error {
return ValidateAutoDeleteMode(autoDelete.Mode)
}
func ValidateIntRange(name string, value, min, max int, required bool) error {
if value == 0 && !required {
return nil
}
if value < min || value > max {
return fmt.Errorf("%s must be between %d and %d", name, min, max)
}
return nil
}
func ValidateTrimmedString(name, value string, required bool) error {
trimmed := strings.TrimSpace(value)
if trimmed == "" {
if !required && value == "" {
return nil
}
return fmt.Errorf("%s cannot be empty", name)
}
return nil
}
func ValidateAutoDeleteMode(mode string) error {
mode = strings.ToLower(strings.TrimSpace(mode))
switch mode {
case "", "none", "single", "all":
return nil
default:
return fmt.Errorf("auto_delete.mode must be one of none, single, all")
}
}

View File

@@ -0,0 +1,61 @@
package config
import (
"strings"
"testing"
)
func TestValidateConfigRejectsInvalidValues(t *testing.T) {
tests := []struct {
name string
cfg Config
want string
}{
{
name: "admin",
cfg: Config{Admin: AdminConfig{JWTExpireHours: 721}},
want: "admin.jwt_expire_hours",
},
{
name: "runtime relation",
cfg: Config{Runtime: RuntimeConfig{
AccountMaxInflight: 8,
GlobalMaxInflight: 4,
}},
want: "runtime.global_max_inflight must be >= runtime.account_max_inflight",
},
{
name: "responses",
cfg: Config{Responses: ResponsesConfig{StoreTTLSeconds: 10}},
want: "responses.store_ttl_seconds",
},
{
name: "embeddings",
cfg: Config{Embeddings: EmbeddingsConfig{Provider: " "}},
want: "embeddings.provider",
},
{
name: "auto delete",
cfg: Config{AutoDelete: AutoDeleteConfig{Mode: "maybe"}},
want: "auto_delete.mode",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := ValidateConfig(tc.cfg)
if err == nil {
t.Fatal("expected validation error")
}
if !strings.Contains(err.Error(), tc.want) {
t.Fatalf("expected %q in error, got %v", tc.want, err)
}
})
}
}
func TestValidateConfigAcceptsLegacyAutoDeleteSessions(t *testing.T) {
if err := ValidateConfig(Config{AutoDelete: AutoDeleteConfig{Sessions: true}}); err != nil {
t.Fatalf("expected legacy auto_delete.sessions config to remain valid, got %v", err)
}
}

View File

@@ -3,6 +3,7 @@ package server
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
@@ -29,8 +30,11 @@ type App struct {
Router http.Handler
}
func NewApp() *App {
store := config.LoadStore()
func NewApp() (*App, error) {
store, err := config.LoadStoreWithError()
if err != nil {
return nil, fmt.Errorf("load config: %w", err)
}
pool := account.NewPool(store)
var dsClient *deepseek.Client
resolver := auth.NewResolver(store, pool, func(ctx context.Context, acc config.Account) (string, error) {
@@ -85,7 +89,7 @@ func NewApp() *App {
http.NotFound(w, req)
})
return &App{Store: store, Pool: pool, Resolver: resolver, DS: dsClient, Router: r}
return &App{Store: store, Pool: pool, Resolver: resolver, DS: dsClient, Router: r}, nil
}
func timeout(d time.Duration) func(http.Handler) http.Handler {

View File

@@ -7,7 +7,13 @@ import (
)
func TestHealthEndpointsSupportHEAD(t *testing.T) {
app := NewApp()
t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@example.com","password":"p"}]}`)
t.Setenv("DS2API_ENV_WRITEBACK", "0")
app, err := NewApp()
if err != nil {
t.Fatalf("NewApp() error: %v", err)
}
for _, path := range []string{"/healthz", "/readyz"} {
req := httptest.NewRequest(http.MethodHead, path, nil)