From a28c9fb67f99c87faa770d098f9ab93a1b667821 Mon Sep 17 00:00:00 2001 From: CJACK Date: Sun, 5 Apr 2026 21:18:51 +0800 Subject: [PATCH] feat: implement comprehensive configuration validation and integrate into store loading and server initialization. --- app/handler.go | 10 ++- cmd/ds2api/main.go | 6 +- internal/admin/handler_settings_parse.go | 40 ++++++----- internal/admin/handler_settings_test.go | 22 ++++++ internal/admin/settings_validation.go | 32 +-------- internal/config/config_test.go | 50 +++++++++++++ internal/config/store.go | 29 ++++++-- internal/config/validation.go | 91 ++++++++++++++++++++++++ internal/config/validation_test.go | 61 ++++++++++++++++ internal/server/router.go | 10 ++- internal/server/router_health_test.go | 8 ++- 11 files changed, 299 insertions(+), 60 deletions(-) create mode 100644 internal/config/validation.go create mode 100644 internal/config/validation_test.go diff --git a/app/handler.go b/app/handler.go index a8979fd..bc26a67 100644 --- a/app/handler.go +++ b/app/handler.go @@ -3,9 +3,17 @@ package app import ( "net/http" + "ds2api/internal/config" "ds2api/internal/server" ) func NewHandler() http.Handler { - return server.NewApp().Router + app, err := server.NewApp() + if err != nil { + config.Logger.Error("[app] init failed", "error", err) + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + server.WriteUnhandledError(w, err) + }) + } + return app.Router } diff --git a/cmd/ds2api/main.go b/cmd/ds2api/main.go index 1f83702..a081a48 100644 --- a/cmd/ds2api/main.go +++ b/cmd/ds2api/main.go @@ -24,7 +24,11 @@ func main() { config.RefreshLogger() webui.EnsureBuiltOnStartup() _ = auth.AdminKey() - app := server.NewApp() + app, err := server.NewApp() + if err != nil { + config.Logger.Error("server initialization failed", "error", err) + os.Exit(1) + } port := strings.TrimSpace(os.Getenv("PORT")) if port == "" { port = "5001" diff --git a/internal/admin/handler_settings_parse.go b/internal/admin/handler_settings_parse.go index 2cefb77..a9bd699 100644 --- a/internal/admin/handler_settings_parse.go +++ b/internal/admin/handler_settings_parse.go @@ -37,8 +37,8 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi cfg := &config.AdminConfig{} if v, exists := raw["jwt_expire_hours"]; exists { n := intFrom(v) - if n < 1 || n > 720 { - return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") + if err := config.ValidateIntRange("admin.jwt_expire_hours", n, 1, 720, true); err != nil { + return nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.JWTExpireHours = n } @@ -49,29 +49,29 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi cfg := &config.RuntimeConfig{} if v, exists := raw["account_max_inflight"]; exists { n := intFrom(v) - if n < 1 || n > 256 { - return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") + if err := config.ValidateIntRange("runtime.account_max_inflight", n, 1, 256, true); err != nil { + return nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.AccountMaxInflight = n } if v, exists := raw["account_max_queue"]; exists { n := intFrom(v) - if n < 1 || n > 200000 { - return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") + if err := config.ValidateIntRange("runtime.account_max_queue", n, 1, 200000, true); err != nil { + return nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.AccountMaxQueue = n } if v, exists := raw["global_max_inflight"]; exists { n := intFrom(v) - if n < 1 || n > 200000 { - return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") + if err := config.ValidateIntRange("runtime.global_max_inflight", n, 1, 200000, true); err != nil { + return nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.GlobalMaxInflight = n } if v, exists := raw["token_refresh_interval_hours"]; exists { n := intFrom(v) - if n < 1 || n > 720 { - return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("runtime.token_refresh_interval_hours must be between 1 and 720") + if err := config.ValidateIntRange("runtime.token_refresh_interval_hours", n, 1, 720, true); err != nil { + return nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.TokenRefreshIntervalHours = n } @@ -98,8 +98,8 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi cfg := &config.ResponsesConfig{} if v, exists := raw["store_ttl_seconds"]; exists { n := intFrom(v) - if n < 30 || n > 86400 { - return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") + if err := config.ValidateIntRange("responses.store_ttl_seconds", n, 30, 86400, true); err != nil { + return nil, nil, nil, nil, nil, nil, nil, nil, err } cfg.StoreTTLSeconds = n } @@ -110,6 +110,9 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi cfg := &config.EmbeddingsConfig{} if v, exists := raw["provider"]; exists { p := strings.TrimSpace(fmt.Sprintf("%v", v)) + if err := config.ValidateTrimmedString("embeddings.provider", p, false); err != nil { + return nil, nil, nil, nil, nil, nil, nil, nil, err + } cfg.Provider = p } embCfg = cfg @@ -143,14 +146,13 @@ func parseSettingsUpdateRequest(req map[string]any) (*config.AdminConfig, *confi cfg := &config.AutoDeleteConfig{} if v, exists := raw["mode"]; exists { mode := strings.ToLower(strings.TrimSpace(fmt.Sprintf("%v", v))) - switch mode { - case "", "none": - cfg.Mode = "none" - case "single", "all": - cfg.Mode = mode - default: - return nil, nil, nil, nil, nil, nil, nil, nil, fmt.Errorf("auto_delete.mode must be one of none, single, all") + if err := config.ValidateAutoDeleteMode(mode); err != nil { + return nil, nil, nil, nil, nil, nil, nil, nil, err } + if mode == "" { + mode = "none" + } + cfg.Mode = mode } if v, exists := raw["sessions"]; exists { cfg.Sessions = boolFrom(v) diff --git a/internal/admin/handler_settings_test.go b/internal/admin/handler_settings_test.go index 159e86f..d698b67 100644 --- a/internal/admin/handler_settings_test.go +++ b/internal/admin/handler_settings_test.go @@ -82,6 +82,28 @@ func TestUpdateSettingsValidationRejectsTokenRefreshInterval(t *testing.T) { } } +func TestUpdateSettingsAllowsEmptyEmbeddingsProvider(t *testing.T) { + h := newAdminTestHandler(t, `{"keys":["k1"]}`) + payload := map[string]any{ + "responses": map[string]any{ + "store_ttl_seconds": 600, + }, + "embeddings": map[string]any{ + "provider": "", + }, + } + b, _ := json.Marshal(payload) + req := httptest.NewRequest(http.MethodPut, "/admin/settings", bytes.NewReader(b)) + rec := httptest.NewRecorder() + h.updateSettings(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("expected 200, got %d body=%s", rec.Code, rec.Body.String()) + } + if got := h.Store.Snapshot().Responses.StoreTTLSeconds; got != 600 { + t.Fatalf("store_ttl_seconds=%d want=600", got) + } +} + func TestUpdateSettingsValidationWithMergedRuntimeSnapshot(t *testing.T) { h := newAdminTestHandler(t, `{ "keys":["k1"], diff --git a/internal/admin/settings_validation.go b/internal/admin/settings_validation.go index 9a03892..c18f955 100644 --- a/internal/admin/settings_validation.go +++ b/internal/admin/settings_validation.go @@ -1,7 +1,6 @@ package admin import ( - "fmt" "strings" "ds2api/internal/config" @@ -16,36 +15,9 @@ func normalizeSettingsConfig(c *config.Config) { } func validateSettingsConfig(c config.Config) error { - if c.Admin.JWTExpireHours != 0 && (c.Admin.JWTExpireHours < 1 || c.Admin.JWTExpireHours > 720) { - return fmt.Errorf("admin.jwt_expire_hours must be between 1 and 720") - } - if err := validateRuntimeSettings(c.Runtime); err != nil { - return err - } - if c.Responses.StoreTTLSeconds != 0 && (c.Responses.StoreTTLSeconds < 30 || c.Responses.StoreTTLSeconds > 86400) { - return fmt.Errorf("responses.store_ttl_seconds must be between 30 and 86400") - } - if c.Embeddings.Provider != "" && strings.TrimSpace(c.Embeddings.Provider) == "" { - return fmt.Errorf("embeddings.provider cannot be empty") - } - return nil + return config.ValidateConfig(c) } func validateRuntimeSettings(runtime config.RuntimeConfig) error { - if runtime.AccountMaxInflight != 0 && (runtime.AccountMaxInflight < 1 || runtime.AccountMaxInflight > 256) { - return fmt.Errorf("runtime.account_max_inflight must be between 1 and 256") - } - if runtime.AccountMaxQueue != 0 && (runtime.AccountMaxQueue < 1 || runtime.AccountMaxQueue > 200000) { - return fmt.Errorf("runtime.account_max_queue must be between 1 and 200000") - } - if runtime.GlobalMaxInflight != 0 && (runtime.GlobalMaxInflight < 1 || runtime.GlobalMaxInflight > 200000) { - return fmt.Errorf("runtime.global_max_inflight must be between 1 and 200000") - } - if runtime.TokenRefreshIntervalHours != 0 && (runtime.TokenRefreshIntervalHours < 1 || runtime.TokenRefreshIntervalHours > 720) { - return fmt.Errorf("runtime.token_refresh_interval_hours must be between 1 and 720") - } - if runtime.AccountMaxInflight > 0 && runtime.GlobalMaxInflight > 0 && runtime.GlobalMaxInflight < runtime.AccountMaxInflight { - return fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") - } - return nil + return config.ValidateRuntimeConfig(runtime) } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index c585b8b..2cc0d3d 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -176,6 +176,56 @@ func TestEnvBackedStoreWritebackDoesNotBootstrapOnInvalidEnvJSON(t *testing.T) { } } +func TestEnvBackedStoreWritebackDoesNotBootstrapOnInvalidSemanticConfig(t *testing.T) { + tmp, err := os.CreateTemp(t.TempDir(), "config-*.json") + if err != nil { + t.Fatalf("create temp config: %v", err) + } + path := tmp.Name() + _ = tmp.Close() + _ = os.Remove(path) + + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["k1"], + "accounts":[{"email":"seed@example.com","password":"p"}], + "runtime":{"account_max_inflight":300} + }`) + t.Setenv("DS2API_CONFIG_PATH", path) + t.Setenv("DS2API_ENV_WRITEBACK", "1") + + cfg, fromEnv, loadErr := loadConfig() + if loadErr == nil { + t.Fatalf("expected loadConfig error for invalid runtime config") + } + if !fromEnv { + t.Fatalf("expected fromEnv=true when env config is the source") + } + if !strings.Contains(loadErr.Error(), "runtime.account_max_inflight") { + t.Fatalf("expected runtime validation error, got %v", loadErr) + } + if len(cfg.Keys) != 1 || len(cfg.Accounts) != 1 { + t.Fatalf("expected env config to be parsed before validation, got keys=%d accounts=%d", len(cfg.Keys), len(cfg.Accounts)) + } + if _, statErr := os.Stat(path); !errors.Is(statErr, os.ErrNotExist) { + t.Fatalf("expected invalid config not to be bootstrapped, stat err=%v", statErr) + } +} + +func TestLoadStoreWithErrorRejectsInvalidRuntimeConfig(t *testing.T) { + t.Setenv("DS2API_CONFIG_JSON", `{ + "keys":["k1"], + "accounts":[{"email":"u@example.com","password":"p"}], + "runtime":{"account_max_inflight":300} + }`) + t.Setenv("DS2API_ENV_WRITEBACK", "0") + + if _, err := LoadStoreWithError(); err == nil { + t.Fatal("expected LoadStoreWithError to reject invalid runtime config") + } else if !strings.Contains(err.Error(), "runtime.account_max_inflight") { + t.Fatalf("expected runtime validation error, got %v", err) + } +} + func TestEnvBackedStoreWritebackFallsBackToPersistedFileOnInvalidEnvJSON(t *testing.T) { tmp, err := os.CreateTemp(t.TempDir(), "config-*.json") if err != nil { diff --git a/internal/config/store.go b/internal/config/store.go index 32b304c..ebee6b0 100644 --- a/internal/config/store.go +++ b/internal/config/store.go @@ -21,16 +21,32 @@ type Store struct { } func LoadStore() *Store { - cfg, fromEnv, err := loadConfig() + store, err := loadStore() if err != nil { Logger.Warn("[config] load failed", "error", err) } - if len(cfg.Keys) == 0 && len(cfg.Accounts) == 0 { + if len(store.cfg.Keys) == 0 && len(store.cfg.Accounts) == 0 { Logger.Warn("[config] empty config loaded") } - s := &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv} - s.rebuildIndexes() - return s + store.rebuildIndexes() + return store +} + +func LoadStoreWithError() (*Store, error) { + store, err := loadStore() + if err != nil { + return nil, err + } + store.rebuildIndexes() + return store, nil +} + +func loadStore() (*Store, error) { + cfg, fromEnv, err := loadConfig() + if validateErr := ValidateConfig(cfg); validateErr != nil { + err = errors.Join(err, validateErr) + } + return &Store{cfg: cfg, path: ConfigPath(), fromEnv: fromEnv}, err } func loadConfig() (Config, bool, error) { @@ -59,6 +75,9 @@ func loadConfig() (Config, bool, error) { } } if errors.Is(fileErr, os.ErrNotExist) { + if validateErr := ValidateConfig(cfg); validateErr != nil { + return cfg, true, validateErr + } if writeErr := writeConfigFile(ConfigPath(), cfg.Clone()); writeErr == nil { return cfg, false, err } else { diff --git a/internal/config/validation.go b/internal/config/validation.go new file mode 100644 index 0000000..eb33abb --- /dev/null +++ b/internal/config/validation.go @@ -0,0 +1,91 @@ +package config + +import ( + "fmt" + "strings" +) + +func ValidateConfig(c Config) error { + if err := ValidateAdminConfig(c.Admin); err != nil { + return err + } + if err := ValidateRuntimeConfig(c.Runtime); err != nil { + return err + } + if err := ValidateResponsesConfig(c.Responses); err != nil { + return err + } + if err := ValidateEmbeddingsConfig(c.Embeddings); err != nil { + return err + } + if err := ValidateAutoDeleteConfig(c.AutoDelete); err != nil { + return err + } + return nil +} + +func ValidateAdminConfig(admin AdminConfig) error { + return ValidateIntRange("admin.jwt_expire_hours", admin.JWTExpireHours, 1, 720, false) +} + +func ValidateRuntimeConfig(runtime RuntimeConfig) error { + if err := ValidateIntRange("runtime.account_max_inflight", runtime.AccountMaxInflight, 1, 256, false); err != nil { + return err + } + if err := ValidateIntRange("runtime.account_max_queue", runtime.AccountMaxQueue, 1, 200000, false); err != nil { + return err + } + if err := ValidateIntRange("runtime.global_max_inflight", runtime.GlobalMaxInflight, 1, 200000, false); err != nil { + return err + } + if err := ValidateIntRange("runtime.token_refresh_interval_hours", runtime.TokenRefreshIntervalHours, 1, 720, false); err != nil { + return err + } + if runtime.AccountMaxInflight > 0 && runtime.GlobalMaxInflight > 0 && runtime.GlobalMaxInflight < runtime.AccountMaxInflight { + return fmt.Errorf("runtime.global_max_inflight must be >= runtime.account_max_inflight") + } + return nil +} + +func ValidateResponsesConfig(responses ResponsesConfig) error { + return ValidateIntRange("responses.store_ttl_seconds", responses.StoreTTLSeconds, 30, 86400, false) +} + +func ValidateEmbeddingsConfig(embeddings EmbeddingsConfig) error { + return ValidateTrimmedString("embeddings.provider", embeddings.Provider, false) +} + +func ValidateAutoDeleteConfig(autoDelete AutoDeleteConfig) error { + return ValidateAutoDeleteMode(autoDelete.Mode) +} + +func ValidateIntRange(name string, value, min, max int, required bool) error { + if value == 0 && !required { + return nil + } + if value < min || value > max { + return fmt.Errorf("%s must be between %d and %d", name, min, max) + } + return nil +} + +func ValidateTrimmedString(name, value string, required bool) error { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + if !required && value == "" { + return nil + } + return fmt.Errorf("%s cannot be empty", name) + } + return nil +} + +func ValidateAutoDeleteMode(mode string) error { + mode = strings.ToLower(strings.TrimSpace(mode)) + switch mode { + case "", "none", "single", "all": + return nil + default: + return fmt.Errorf("auto_delete.mode must be one of none, single, all") + } +} diff --git a/internal/config/validation_test.go b/internal/config/validation_test.go new file mode 100644 index 0000000..00b2929 --- /dev/null +++ b/internal/config/validation_test.go @@ -0,0 +1,61 @@ +package config + +import ( + "strings" + "testing" +) + +func TestValidateConfigRejectsInvalidValues(t *testing.T) { + tests := []struct { + name string + cfg Config + want string + }{ + { + name: "admin", + cfg: Config{Admin: AdminConfig{JWTExpireHours: 721}}, + want: "admin.jwt_expire_hours", + }, + { + name: "runtime relation", + cfg: Config{Runtime: RuntimeConfig{ + AccountMaxInflight: 8, + GlobalMaxInflight: 4, + }}, + want: "runtime.global_max_inflight must be >= runtime.account_max_inflight", + }, + { + name: "responses", + cfg: Config{Responses: ResponsesConfig{StoreTTLSeconds: 10}}, + want: "responses.store_ttl_seconds", + }, + { + name: "embeddings", + cfg: Config{Embeddings: EmbeddingsConfig{Provider: " "}}, + want: "embeddings.provider", + }, + { + name: "auto delete", + cfg: Config{AutoDelete: AutoDeleteConfig{Mode: "maybe"}}, + want: "auto_delete.mode", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := ValidateConfig(tc.cfg) + if err == nil { + t.Fatal("expected validation error") + } + if !strings.Contains(err.Error(), tc.want) { + t.Fatalf("expected %q in error, got %v", tc.want, err) + } + }) + } +} + +func TestValidateConfigAcceptsLegacyAutoDeleteSessions(t *testing.T) { + if err := ValidateConfig(Config{AutoDelete: AutoDeleteConfig{Sessions: true}}); err != nil { + t.Fatalf("expected legacy auto_delete.sessions config to remain valid, got %v", err) + } +} diff --git a/internal/server/router.go b/internal/server/router.go index 7557afb..cf44bdb 100644 --- a/internal/server/router.go +++ b/internal/server/router.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/json" + "fmt" "net/http" "strings" "time" @@ -29,8 +30,11 @@ type App struct { Router http.Handler } -func NewApp() *App { - store := config.LoadStore() +func NewApp() (*App, error) { + store, err := config.LoadStoreWithError() + if err != nil { + return nil, fmt.Errorf("load config: %w", err) + } pool := account.NewPool(store) var dsClient *deepseek.Client resolver := auth.NewResolver(store, pool, func(ctx context.Context, acc config.Account) (string, error) { @@ -85,7 +89,7 @@ func NewApp() *App { http.NotFound(w, req) }) - return &App{Store: store, Pool: pool, Resolver: resolver, DS: dsClient, Router: r} + return &App{Store: store, Pool: pool, Resolver: resolver, DS: dsClient, Router: r}, nil } func timeout(d time.Duration) func(http.Handler) http.Handler { diff --git a/internal/server/router_health_test.go b/internal/server/router_health_test.go index 0f744dd..7c79d31 100644 --- a/internal/server/router_health_test.go +++ b/internal/server/router_health_test.go @@ -7,7 +7,13 @@ import ( ) func TestHealthEndpointsSupportHEAD(t *testing.T) { - app := NewApp() + t.Setenv("DS2API_CONFIG_JSON", `{"keys":["k1"],"accounts":[{"email":"u@example.com","password":"p"}]}`) + t.Setenv("DS2API_ENV_WRITEBACK", "0") + + app, err := NewApp() + if err != nil { + t.Fatalf("NewApp() error: %v", err) + } for _, path := range []string{"/healthz", "/readyz"} { req := httptest.NewRequest(http.MethodHead, path, nil)