// Package gateway exposes the HTTP API for the AI gateway: health probes,
// synchronous/asynchronous content moderation, generation job submission,
// and job status lookup.
package gateway

import (
	"context"
	"encoding/json"
	"log"
	"net/http"
	"strings"
	"time"

	"ai-gateway/internal/config"
	"ai-gateway/internal/ollama"
	"ai-gateway/internal/queue"

	"github.com/google/uuid"
)

// judgeModel is the moderation model used on the synchronous fast path.
const judgeModel = "llama-guard3:1b"

// Handler routes and serves all gateway HTTP endpoints.
type Handler struct {
	cfg    *config.Config
	q      *queue.Queue
	ollama *ollama.Client
}

// New constructs a Handler wired to the given config, queue, and Ollama client.
func New(cfg *config.Config, q *queue.Queue, oc *ollama.Client) *Handler {
	return &Handler{cfg: cfg, q: q, ollama: oc}
}

// ServeHTTP dispatches requests to the endpoint handlers. All /v1 routes
// require the internal auth token (when one is configured); health probes
// are always unauthenticated.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	path := r.URL.Path
	switch {
	case path == "/healthz" && r.Method == http.MethodGet:
		h.healthz(w, r)
	case path == "/readyz" && r.Method == http.MethodGet:
		h.readyz(w, r)
	case path == "/v1/moderate" && r.Method == http.MethodPost:
		h.authMiddleware(h.moderate)(w, r)
	case path == "/v1/generate" && r.Method == http.MethodPost:
		h.authMiddleware(h.generate)(w, r)
	case strings.HasPrefix(path, "/v1/jobs/") && r.Method == http.MethodGet:
		h.authMiddleware(h.getJob)(w, r)
	default:
		jsonError(w, http.StatusNotFound, "not found")
	}
}

// authMiddleware rejects requests whose X-Internal-Token header does not
// match the configured token. An empty configured token disables auth.
func (h *Handler) authMiddleware(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if h.cfg.InternalToken == "" {
			next(w, r)
			return
		}
		if r.Header.Get("X-Internal-Token") != h.cfg.InternalToken {
			jsonError(w, http.StatusUnauthorized, "unauthorized")
			return
		}
		next(w, r)
	}
}

// healthz is a liveness probe: it always reports ok.
func (h *Handler) healthz(w http.ResponseWriter, _ *http.Request) {
	jsonOK(w, map[string]any{"status": "ok", "time": time.Now().UTC()})
}

// readyz is a readiness probe. Redis being down makes the service not ready
// (503); Ollama state, circuit status, and queue depths are included in the
// body for observability but do not affect readiness.
func (h *Handler) readyz(w http.ResponseWriter, r *http.Request) {
	ctx, cancel := context.WithTimeout(r.Context(), 2*time.Second)
	defer cancel()

	status := map[string]any{"time": time.Now().UTC()}
	ready := true
	if err := h.q.Ping(ctx); err != nil {
		status["redis"] = "down"
		ready = false
	} else {
		status["redis"] = "ok"
	}
	if err := h.ollama.Healthz(ctx); err != nil {
		status["ollama"] = "down"
	} else {
		status["ollama"] = "ok"
	}
	status["ollama_circuit"] = h.ollama.IsAvailable()

	// Queue depths are best-effort diagnostics; lookup errors leave them at zero.
	writerLen, _ := h.q.QueueLen(ctx, "writer")
	judgeLen, _ := h.q.QueueLen(ctx, "judge")
	status["queue_writer"] = writerLen
	status["queue_judge"] = judgeLen

	if ready {
		status["status"] = "ready"
		jsonOK(w, status)
		return
	}
	status["status"] = "not_ready"
	writeJSON(w, http.StatusServiceUnavailable, status)
}

// ModerateRequest is the body of POST /v1/moderate.
type ModerateRequest struct {
	Text    string         `json:"text"`
	Context map[string]any `json:"context,omitempty"`
}

// moderate handles POST /v1/moderate. Short texts are judged synchronously
// (with a verdict cache); long texts are enqueued as async judge jobs and
// answered with 202 + a job ID. Judge errors fail open (allowed=true) so
// moderation outages never block traffic.
func (h *Handler) moderate(w http.ResponseWriter, r *http.Request) {
	if h.cfg.AIDisabled {
		jsonOK(w, map[string]any{"allowed": true, "reason": "ai_disabled", "cached": false})
		return
	}
	var req ModerateRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		jsonError(w, http.StatusBadRequest, "invalid json")
		return
	}
	if req.Text == "" {
		jsonError(w, http.StatusBadRequest, "text required")
		return
	}

	// Global rate limit. A limiter error is treated the same as "limited"
	// so a broken limiter cannot be used to bypass throttling.
	ok, err := h.q.CheckRate(r.Context(), "global", "moderate", h.cfg.ModerateRateLimit)
	if err != nil || !ok {
		jsonError(w, http.StatusTooManyRequests, "rate limit exceeded")
		return
	}

	// Serve a cached verdict when one exists for this exact text.
	if cached, err := h.q.GetModCache(r.Context(), req.Text); err == nil {
		var result map[string]any
		if json.Unmarshal(cached, &result) == nil {
			result["cached"] = true
			jsonOK(w, result)
			return
		}
	}

	// Synchronous fast path for short texts while the Ollama circuit is closed.
	if len(req.Text) <= h.cfg.SyncMaxChars && h.ollama.IsAvailable() {
		result, err := h.runJudge(r.Context(), req.Text)
		if err != nil {
			log.Printf("[moderate] sync judge error: %v", err)
			// Fail open: moderation errors must not block content.
			jsonOK(w, map[string]any{"allowed": true, "reason": "judge_error", "error": err.Error()})
			return
		}
		// Cache the verdict before tagging it, so cache replays get "cached": true.
		if data, err := json.Marshal(result); err == nil {
			h.q.SetModCache(r.Context(), req.Text, data)
		}
		result["cached"] = false
		jsonOK(w, result)
		return
	}

	// Async path for long texts.
	h.enqueueJob(w, r, "judge", req)
}

// guardCategories maps LLaMA Guard 3 S-codes to human-readable category names.
var guardCategories = map[string]string{
	"S1":  "violent_crimes",
	"S2":  "non_violent_crimes",
	"S3":  "sex_related_crimes",
	"S4":  "child_sexual_exploitation",
	"S5":  "defamation",
	"S6":  "specialized_advice",
	"S7":  "privacy",
	"S8":  "intellectual_property",
	"S9":  "indiscriminate_weapons",
	"S10": "hate",
	"S11": "suicide_self_harm",
	"S12": "sexual_content",
	"S13": "elections",
	"S14": "code_interpreter_abuse",
}

// highSeverityCodes are categories that should always be severity "high".
var highSeverityCodes = map[string]bool{"S1": true, "S3": true, "S4": true, "S9": true}

// parseGuardOutput parses LLaMA Guard 3's native output format into a
// moderation verdict map.
//
//	Safe output:   "safe"
//	Unsafe output: "unsafe\nS1,S4" or "unsafe\nS1"
//
// Anything that is not recognizably "safe" is treated as unsafe (fail
// closed); if no S-codes can be extracted, a generic "policy_violation"
// category is reported.
func parseGuardOutput(raw string) map[string]any {
	content := strings.TrimSpace(raw)
	lower := strings.ToLower(content)
	if lower == "safe" || strings.HasPrefix(lower, "safe\n") || strings.HasPrefix(lower, "safe ") {
		return map[string]any{"allowed": true, "categories": []string{}, "severity": "low", "reason": ""}
	}

	// Parse "unsafe\nS1,S2,...". Slices start non-nil so they JSON-encode
	// as [] rather than null.
	categories := []string{}
	codes := []string{}
	severity := "medium"
	lines := strings.Split(content, "\n")
	if len(lines) > 1 {
		// The second line carries the comma-separated S-codes.
		for _, p := range strings.Split(strings.TrimSpace(lines[1]), ",") {
			code := strings.TrimSpace(p)
			if code == "" {
				continue
			}
			codes = append(codes, code)
			if name, ok := guardCategories[code]; ok {
				categories = append(categories, name)
			} else {
				// Unknown code: surface it verbatim rather than dropping it.
				categories = append(categories, code)
			}
			if highSeverityCodes[code] {
				severity = "high"
			}
		}
	}
	if len(categories) == 0 {
		categories = []string{"policy_violation"}
	}
	return map[string]any{
		"allowed":    false,
		"categories": categories,
		"codes":      codes,
		"severity":   severity,
		"reason":     strings.Join(categories, ", "),
	}
}

// runJudge asks the guard model for a verdict on text and parses its reply.
func (h *Handler) runJudge(ctx context.Context, text string) (map[string]any, error) {
	resp, err := h.ollama.Chat(ctx, &ollama.ChatRequest{
		Model: judgeModel,
		Messages: []ollama.ChatMessage{
			{Role: "user", Content: text},
		},
		Stream: false,
		Options: &ollama.ModelOptions{
			Temperature: 0.0, // deterministic verdicts
			NumPredict:  64,  // verdict output is short; cap generation
		},
	})
	if err != nil {
		return nil, err
	}
	return parseGuardOutput(resp.Message.Content), nil
}

// GenerateRequest is the body of POST /v1/generate.
type GenerateRequest struct {
	Task  string         `json:"task"`
	Input map[string]any `json:"input"`
}

// generate handles POST /v1/generate by enqueueing an async writer job.
// Unlike moderation, generation fails closed when AI is disabled (503).
func (h *Handler) generate(w http.ResponseWriter, r *http.Request) {
	if h.cfg.AIDisabled {
		jsonError(w, http.StatusServiceUnavailable, "ai_disabled")
		return
	}
	var req GenerateRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		jsonError(w, http.StatusBadRequest, "invalid json")
		return
	}
	if req.Task == "" {
		jsonError(w, http.StatusBadRequest, "task required")
		return
	}
	ok, err := h.q.CheckRate(r.Context(), "global", "generate", h.cfg.GenerateRateLimit)
	if err != nil || !ok {
		jsonError(w, http.StatusTooManyRequests, "rate limit exceeded")
		return
	}
	h.enqueueJob(w, r, "writer", req)
}

// enqueueJob marshals payload into a new queued job of the given type,
// enqueues it, and responds 202 with the job ID (or 500 on failure).
// Shared by the moderate and generate async paths.
func (h *Handler) enqueueJob(w http.ResponseWriter, r *http.Request, jobType string, payload any) {
	input, err := json.Marshal(payload)
	if err != nil {
		jsonError(w, http.StatusInternalServerError, "queue error")
		return
	}
	jobID := uuid.New().String()
	job := &queue.Job{
		ID:        jobID,
		Type:      jobType,
		CreatedAt: time.Now().UTC(),
		Input:     input,
		Status:    "queued",
	}
	if err := h.q.Enqueue(r.Context(), job); err != nil {
		jsonError(w, http.StatusInternalServerError, "queue error")
		return
	}
	writeJSON(w, http.StatusAccepted, map[string]any{"job_id": jobID, "status": "queued"})
}

// getJob handles GET /v1/jobs/{id}, returning the stored job record.
func (h *Handler) getJob(w http.ResponseWriter, r *http.Request) {
	jobID := strings.TrimPrefix(r.URL.Path, "/v1/jobs/")
	if jobID == "" {
		jsonError(w, http.StatusBadRequest, "job_id required")
		return
	}
	job, err := h.q.GetJob(r.Context(), jobID)
	if err != nil {
		// NOTE(review): lookup errors and true misses both surface as 404
		// here — confirm queue.GetJob distinguishes them if callers care.
		jsonError(w, http.StatusNotFound, "job not found")
		return
	}
	jsonOK(w, job)
}

// writeJSON writes v as a JSON response body with the given status code.
func writeJSON(w http.ResponseWriter, code int, v any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	if err := json.NewEncoder(w).Encode(v); err != nil {
		// Headers are already sent; all we can do is record the failure.
		log.Printf("[gateway] encode response: %v", err)
	}
}

// jsonOK writes v as a 200 JSON response.
func jsonOK(w http.ResponseWriter, data any) {
	writeJSON(w, http.StatusOK, data)
}

// jsonError writes a JSON error body with the given status code.
func jsonError(w http.ResponseWriter, code int, msg string) {
	writeJSON(w, code, map[string]string{"error": msg})
}