// sojorn/go-backend/internal/services/local_ai_service.go

package services

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/rs/zerolog/log"
)

// LocalAIService communicates with the on-server AI Gateway (localhost:8099).
// It provides text moderation via llama-guard and content generation via
// qwen2.5. It runs alongside OpenRouter; both engines are available
// simultaneously.
type LocalAIService struct {
	baseURL       string
	token         string
	httpClient    *http.Client
	mu            sync.RWMutex
	circuitOpen   bool
	circuitUntil  time.Time
	circuitWindow time.Duration
}

// LocalAIModerationResult is the response from the local AI gateway /v1/moderate endpoint.
type LocalAIModerationResult struct {
	Allowed    bool     `json:"allowed"`
	Categories []string `json:"categories"`
	Severity   string   `json:"severity"`
	Reason     string   `json:"reason"`
	Cached     bool     `json:"cached"`
	Error      string   `json:"error,omitempty"`
}

// LocalAIJobResponse is returned when a job is submitted asynchronously.
type LocalAIJobResponse struct {
	JobID  string `json:"job_id"`
	Status string `json:"status"`
}

// LocalAIJob is the full job object returned when polling.
type LocalAIJob struct {
	ID        string          `json:"id"`
	Type      string          `json:"type"`
	CreatedAt time.Time       `json:"created_at"`
	Status    string          `json:"status"`
	Result    json.RawMessage `json:"result,omitempty"`
	Error     string          `json:"error,omitempty"`
}

// LocalAIHealthStatus is returned by the /readyz endpoint.
type LocalAIHealthStatus struct {
	Status        string `json:"status"`
	Redis         string `json:"redis"`
	Ollama        string `json:"ollama"`
	OllamaCircuit bool   `json:"ollama_circuit"`
	QueueWriter   int64  `json:"queue_writer"`
	QueueJudge    int64  `json:"queue_judge"`
}

// NewLocalAIService builds a client for the local AI gateway. It returns nil
// when baseURL is empty, signalling that the local gateway is not configured.
func NewLocalAIService(baseURL, token string) *LocalAIService {
	if baseURL == "" {
		return nil
	}
	return &LocalAIService{
		baseURL: strings.TrimRight(baseURL, "/"),
		token:   token,
		httpClient: &http.Client{
			Timeout: 90 * time.Second,
			Transport: &http.Transport{
				MaxIdleConns:        5,
				MaxIdleConnsPerHost: 5,
				IdleConnTimeout:     60 * time.Second,
			},
		},
		circuitWindow: 30 * time.Second,
	}
}
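
// Usage sketch (hypothetical; the cfg fields below are assumptions about the
// surrounding application, not part of this file):
//
//	svc := NewLocalAIService(cfg.LocalAIBaseURL, cfg.LocalAIToken) // e.g. "http://localhost:8099"
//	if svc == nil {
//		log.Warn().Msg("local AI gateway not configured; OpenRouter only")
//	}
//
// A nil *LocalAIService must not be used for anything except Healthz, which
// is the only nil-safe method.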

// isAvailable reports whether calls may proceed; it returns false while the
// circuit breaker window is open.
func (s *LocalAIService) isAvailable() bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	if s.circuitOpen && time.Now().Before(s.circuitUntil) {
		return false
	}
	return true
}

// tripCircuit opens the breaker for circuitWindow after a gateway failure.
func (s *LocalAIService) tripCircuit() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.circuitOpen = true
	s.circuitUntil = time.Now().Add(s.circuitWindow)
	log.Warn().Msg("[local-ai] circuit breaker tripped")
}

// resetCircuit clears the breaker after a successful request.
func (s *LocalAIService) resetCircuit() {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.circuitOpen {
		s.circuitOpen = false
		log.Info().Msg("[local-ai] circuit breaker reset")
	}
}
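
// Breaker timing sketch: after tripCircuit, isAvailable fails fast for
// circuitWindow (30s); once the window elapses the gateway is probed again,
// and the next successful call clears the flag via resetCircuit. Illustrative
// timings (hypothetical):
//
//	s.tripCircuit()  // t=0s:  circuitOpen=true, circuitUntil=t+30s
//	s.isAvailable()  // t=5s:  false (inside the window)
//	s.isAvailable()  // t=35s: true (window elapsed; flag still set)
//	s.resetCircuit() // after the next successful request: circuitOpen=false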

// ModerateText sends text to the local AI gateway for moderation.
// It returns a local_ai_unavailable error when the circuit breaker is open or
// the gateway cannot be reached; the caller should fall back to OpenRouter.
func (s *LocalAIService) ModerateText(ctx context.Context, text string) (*LocalAIModerationResult, error) {
	if !s.isAvailable() {
		return nil, fmt.Errorf("local_ai_unavailable: circuit breaker open")
	}
	// Marshaling a map[string]string cannot fail, so the error is ignored.
	body, _ := json.Marshal(map[string]string{"text": text})
	req, err := http.NewRequestWithContext(ctx, "POST", s.baseURL+"/v1/moderate", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("request error: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if s.token != "" {
		req.Header.Set("X-Internal-Token", s.token)
	}
	resp, err := s.httpClient.Do(req)
	if err != nil {
		s.tripCircuit()
		return nil, fmt.Errorf("local_ai_unavailable: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode == http.StatusTooManyRequests {
		return nil, fmt.Errorf("local_ai_rate_limited")
	}
	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusAccepted {
		respBody, _ := io.ReadAll(resp.Body)
		s.tripCircuit()
		return nil, fmt.Errorf("local_ai error %d: %s", resp.StatusCode, string(respBody))
	}
	s.resetCircuit()
	// Async response (long text): the gateway queued the job instead of
	// answering inline.
	if resp.StatusCode == http.StatusAccepted {
		var jobResp LocalAIJobResponse
		if err := json.NewDecoder(resp.Body).Decode(&jobResp); err != nil {
			return nil, fmt.Errorf("decode error: %w", err)
		}
		log.Info().Str("job_id", jobResp.JobID).Msg("[local-ai] moderation queued async")
		// For async jobs, fail open (allowed=true); the job can be polled later.
		return &LocalAIModerationResult{Allowed: true, Reason: "async_queued", Severity: "pending"}, nil
	}
	var result LocalAIModerationResult
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("decode error: %w", err)
	}
	return &result, nil
}
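
// Fallback sketch (hypothetical; openRouterModerate is an assumed helper, not
// part of this package): callers treat any error from ModerateText as a
// signal to try the OpenRouter path instead.
//
//	res, err := svc.ModerateText(ctx, text)
//	if err != nil {
//		return openRouterModerate(ctx, text) // gateway down, rate limited, or circuit open
//	}
//	if !res.Allowed {
//		return fmt.Errorf("content blocked (%s): %s", res.Severity, res.Reason)
//	}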

// SubmitGeneration submits a content generation job to the local AI gateway.
// It returns the job ID for polling via GetJob.
func (s *LocalAIService) SubmitGeneration(ctx context.Context, task string, input map[string]any) (*LocalAIJobResponse, error) {
	if !s.isAvailable() {
		return nil, fmt.Errorf("local_ai_unavailable: circuit breaker open")
	}
	// input is caller-supplied and may contain unmarshalable values, so the
	// error is checked here (unlike the fixed-shape moderation payload).
	body, err := json.Marshal(map[string]any{"task": task, "input": input})
	if err != nil {
		return nil, fmt.Errorf("marshal error: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, "POST", s.baseURL+"/v1/generate", bytes.NewReader(body))
	if err != nil {
		return nil, fmt.Errorf("request error: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if s.token != "" {
		req.Header.Set("X-Internal-Token", s.token)
	}
	resp, err := s.httpClient.Do(req)
	if err != nil {
		s.tripCircuit()
		return nil, fmt.Errorf("local_ai_unavailable: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusAccepted {
		respBody, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("local_ai generate error %d: %s", resp.StatusCode, string(respBody))
	}
	s.resetCircuit()
	var result LocalAIJobResponse
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("decode error: %w", err)
	}
	return &result, nil
}
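
// Submission sketch (the task name "summarize" and the input shape are
// assumptions; the gateway's accepted tasks are not defined in this file):
//
//	job, err := svc.SubmitGeneration(ctx, "summarize", map[string]any{
//		"text":       postBody,
//		"max_tokens": 256,
//	})
//	if err != nil {
//		return err
//	}
//	// job.JobID can now be polled with GetJob until the job settles.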

// GetJob polls a job's status on the local AI gateway.
func (s *LocalAIService) GetJob(ctx context.Context, jobID string) (*LocalAIJob, error) {
	if !s.isAvailable() {
		return nil, fmt.Errorf("local_ai_unavailable: circuit breaker open")
	}
	req, err := http.NewRequestWithContext(ctx, "GET", s.baseURL+"/v1/jobs/"+jobID, nil)
	if err != nil {
		return nil, fmt.Errorf("request error: %w", err)
	}
	if s.token != "" {
		req.Header.Set("X-Internal-Token", s.token)
	}
	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("local_ai_unavailable: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("local_ai job %s: unexpected status %d", jobID, resp.StatusCode)
	}
	var job LocalAIJob
	if err := json.NewDecoder(resp.Body).Decode(&job); err != nil {
		return nil, fmt.Errorf("decode error: %w", err)
	}
	return &job, nil
}
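
// Polling sketch (the status strings "completed" and "failed" are assumptions
// about the gateway's job states, which are not defined in this file):
//
//	for {
//		job, err := svc.GetJob(ctx, jobID)
//		if err != nil {
//			return err
//		}
//		switch job.Status {
//		case "completed":
//			return json.Unmarshal(job.Result, &out)
//		case "failed":
//			return fmt.Errorf("generation failed: %s", job.Error)
//		}
//		select {
//		case <-ctx.Done():
//			return ctx.Err()
//		case <-time.After(2 * time.Second): // back off between polls
//		}
//	}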

// Healthz checks whether the local AI gateway is healthy via its /readyz
// endpoint. It is nil-safe so callers can use it without a prior nil check.
func (s *LocalAIService) Healthz(ctx context.Context) (*LocalAIHealthStatus, error) {
	if s == nil {
		return nil, fmt.Errorf("local AI service not configured")
	}
	req, err := http.NewRequestWithContext(ctx, "GET", s.baseURL+"/readyz", nil)
	if err != nil {
		return nil, err
	}
	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	// /readyz may answer with a non-200 status while still returning a JSON
	// body describing the degraded state, so the body is decoded regardless.
	var status LocalAIHealthStatus
	if err := json.NewDecoder(resp.Body).Decode(&status); err != nil {
		return nil, fmt.Errorf("decode error: %w", err)
	}
	return &status, nil
}
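
// Readiness sketch (the "ok" status value is an assumption about the /readyz
// payload): surfacing gateway health in the app's own health endpoint, with
// OpenRouter still available when the local engine is degraded.
//
//	st, err := svc.Healthz(ctx)
//	if err != nil || st.Status != "ok" {
//		// report the local AI engine as degraded; keep serving via OpenRouter
//	}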