// sojorn/go-backend/internal/services/openrouter_service.go
// (file-viewer metadata: 367 lines, 12 KiB, Go)

package services
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5/pgxpool"
)
// OpenRouterService handles interactions with the OpenRouter API: listing
// available models, reading and writing per-type moderation settings in the
// ai_moderation_config table, and sending content to the configured model
// for moderation via chat completions.
type OpenRouterService struct {
	pool       *pgxpool.Pool // used for the ai_moderation_config queries
	httpClient *http.Client  // shared client; NewOpenRouterService gives it a 60s timeout
	apiKey     string        // OpenRouter key; callModel refuses to run when it is empty
	// Cached model list
	// modelCache and modelCacheTime are read and written under modelCacheMu
	// by ListModels.
	modelCache     []OpenRouterModel
	modelCacheMu   sync.RWMutex
	modelCacheTime time.Time // when modelCache was last refreshed
}
// OpenRouterModel represents a model available on OpenRouter, as decoded from
// the "data" array of GET https://openrouter.ai/api/v1/models in ListModels.
// Architecture, TopProvider and PerRequestLimits are kept as free-form maps;
// nothing in this file interprets their keys.
type OpenRouterModel struct {
	ID               string            `json:"id"`
	Name             string            `json:"name"`
	Description      string            `json:"description,omitempty"`
	Pricing          OpenRouterPricing `json:"pricing"`
	ContextLength    int               `json:"context_length"`
	Architecture     map[string]any    `json:"architecture,omitempty"`
	TopProvider      map[string]any    `json:"top_provider,omitempty"`
	PerRequestLimits map[string]any    `json:"per_request_limits,omitempty"`
}
// OpenRouterPricing holds a model's pricing fields exactly as OpenRouter
// sends them. The values are kept as strings rather than parsed numbers.
// NOTE(review): presumably USD amounts per token/request; confirm against
// the OpenRouter models API docs before doing arithmetic on them.
type OpenRouterPricing struct {
	Prompt     string `json:"prompt"`
	Completion string `json:"completion"`
	Image      string `json:"image,omitempty"`
	Request    string `json:"request,omitempty"`
}
// ModerationConfigEntry represents a row in ai_moderation_config.
// ModerationType is the lookup key: GetModerationConfig selects by it and
// SetModerationConfig upserts on it. This file uses the types "text",
// "image" and "video".
type ModerationConfigEntry struct {
	ID             string    `json:"id"`
	ModerationType string    `json:"moderation_type"`
	ModelID        string    `json:"model_id"`      // OpenRouter model ID sent as "model"; empty means not configured
	ModelName      string    `json:"model_name"`
	SystemPrompt   string    `json:"system_prompt"` // empty falls back to defaultModerationSystemPrompt in callModel
	Enabled        bool      `json:"enabled"`
	UpdatedAt      time.Time `json:"updated_at"`
	UpdatedBy      *string   `json:"updated_by,omitempty"` // pointer so a NULL updated_by can be represented
}
// OpenRouterChatMessage represents a message in a chat completion request.
// Content is either a plain string (system prompts and text-only user
// messages) or a slice of content-part maps such as {"type": "text", ...}
// and {"type": "image_url", ...} for multimodal user messages.
type OpenRouterChatMessage struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
}
// OpenRouterChatRequest represents a chat completion request body posted to
// https://openrouter.ai/api/v1/chat/completions by callModel. Only the model
// and the message list are sent.
type OpenRouterChatRequest struct {
	Model    string                  `json:"model"`
	Messages []OpenRouterChatMessage `json:"messages"`
}
// OpenRouterChatResponse represents the part of a chat completion response
// that this service decodes. callModel reads only the first choice's
// message content; Usage is decoded but not read by callModel.
type OpenRouterChatResponse struct {
	ID      string `json:"id"`
	Choices []struct {
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}
// NewOpenRouterService returns an OpenRouterService that keeps its
// moderation settings in pool and authenticates to OpenRouter with apiKey.
// apiKey may be empty: ListModels then calls the API without an
// Authorization header, and moderation calls fail. Every HTTP request made
// by the service is bounded by a 60-second client timeout.
func NewOpenRouterService(pool *pgxpool.Pool, apiKey string) *OpenRouterService {
	client := &http.Client{Timeout: 60 * time.Second}

	svc := &OpenRouterService{httpClient: client}
	svc.pool = pool
	svc.apiKey = apiKey
	return svc
}
// ListModels fetches available models from OpenRouter, with 1-hour cache.
//
// A non-empty cached list younger than one hour is served without a network
// call. Otherwise the list is fetched from GET /api/v1/models (sending the
// API key as a Bearer token when one is configured), stored in the cache and
// returned.
//
// The returned slice is always a fresh copy. Callers may therefore sort or
// filter it in place without mutating the shared cache that concurrent
// callers read. The copy is shallow: map fields are still shared.
//
// A non-200 response is reported with at most the first 4 KiB of its body.
func (s *OpenRouterService) ListModels(ctx context.Context) ([]OpenRouterModel, error) {
	const maxErrorBodyBytes = 4 << 10

	s.modelCacheMu.RLock()
	if len(s.modelCache) > 0 && time.Since(s.modelCacheTime) < time.Hour {
		// Copy while still holding the read lock so the snapshot is consistent.
		cached := append(s.modelCache[:0:0], s.modelCache...)
		s.modelCacheMu.RUnlock()
		return cached, nil
	}
	s.modelCacheMu.RUnlock()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://openrouter.ai/api/v1/models", nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	if s.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+s.apiKey)
	}

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch models: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: a read error only shortens the message. The limit keeps
		// a large HTML error page out of the error string and logs.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("OpenRouter API error %d: %s", resp.StatusCode, string(body))
	}

	var result struct {
		Data []OpenRouterModel `json:"data"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, fmt.Errorf("failed to decode models: %w", err)
	}

	s.modelCacheMu.Lock()
	s.modelCache = result.Data
	s.modelCacheTime = time.Now()
	s.modelCacheMu.Unlock()

	// Hand the caller its own copy; result.Data is now owned by the cache.
	// x[:0:0] keeps nil-vs-empty identical to the decoded value.
	return append(result.Data[:0:0], result.Data...), nil
}
// GetModerationConfigs returns all moderation type configurations, i.e.
// every row of ai_moderation_config ordered by moderation_type. The result
// is nil when the table is empty.
//
// Errors from the query, from scanning any row, or from the row iteration
// itself (surfaced by rows.Err after the loop) are returned wrapped.
func (s *OpenRouterService) GetModerationConfigs(ctx context.Context) ([]ModerationConfigEntry, error) {
	rows, err := s.pool.Query(ctx, `
SELECT id, moderation_type, model_id, model_name, system_prompt, enabled, updated_at, updated_by
FROM ai_moderation_config
ORDER BY moderation_type
`)
	if err != nil {
		return nil, fmt.Errorf("failed to query configs: %w", err)
	}
	defer rows.Close()

	var configs []ModerationConfigEntry
	for rows.Next() {
		var c ModerationConfigEntry
		if err := rows.Scan(&c.ID, &c.ModerationType, &c.ModelID, &c.ModelName, &c.SystemPrompt, &c.Enabled, &c.UpdatedAt, &c.UpdatedBy); err != nil {
			return nil, fmt.Errorf("failed to scan config: %w", err)
		}
		configs = append(configs, c)
	}
	// rows.Next returns false both at the end and on failure; only rows.Err
	// distinguishes a complete result set from a truncated one.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to read configs: %w", err)
	}
	return configs, nil
}
// GetModerationConfig loads the single ai_moderation_config row whose
// moderation_type equals moderationType. Query and scan errors are returned
// unwrapped. NOTE(review): presumably pgx.ErrNoRows when no row exists for
// the type; confirm against callers that special-case it.
func (s *OpenRouterService) GetModerationConfig(ctx context.Context, moderationType string) (*ModerationConfigEntry, error) {
	entry := new(ModerationConfigEntry)

	row := s.pool.QueryRow(ctx, `
SELECT id, moderation_type, model_id, model_name, system_prompt, enabled, updated_at, updated_by
FROM ai_moderation_config WHERE moderation_type = $1
`, moderationType)

	scanErr := row.Scan(
		&entry.ID,
		&entry.ModerationType,
		&entry.ModelID,
		&entry.ModelName,
		&entry.SystemPrompt,
		&entry.Enabled,
		&entry.UpdatedAt,
		&entry.UpdatedBy,
	)
	if scanErr != nil {
		return nil, scanErr
	}
	return entry, nil
}
// SetModerationConfig upserts a moderation config. It inserts the row for
// moderationType, or overwrites model, prompt, enabled flag and updated_by
// when the row already exists. updated_at is set to the database's NOW()
// either way. The Exec error, if any, is returned unwrapped.
// NOTE(review): ON CONFLICT (moderation_type) assumes a unique constraint on
// that column; confirm in the schema.
func (s *OpenRouterService) SetModerationConfig(ctx context.Context, moderationType, modelID, modelName, systemPrompt string, enabled bool, updatedBy string) error {
	const upsert = `
INSERT INTO ai_moderation_config (moderation_type, model_id, model_name, system_prompt, enabled, updated_by, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, NOW())
ON CONFLICT (moderation_type)
DO UPDATE SET model_id = $2, model_name = $3, system_prompt = $4, enabled = $5, updated_by = $6, updated_at = NOW()
`
	args := []any{moderationType, modelID, modelName, systemPrompt, enabled, updatedBy}

	if _, execErr := s.pool.Exec(ctx, upsert, args...); execErr != nil {
		return execErr
	}
	return nil
}
// ModerateText sends text content to the model configured for the "text"
// moderation type. It fails when that type is missing, disabled or has no
// model, or when the model call fails.
func (s *OpenRouterService) ModerateText(ctx context.Context, content string) (*ModerationResult, error) {
	config, err := s.enabledModerationConfig(ctx, "text")
	if err != nil {
		return nil, err
	}
	return s.callModel(ctx, config.ModelID, config.SystemPrompt, content, nil)
}

// ModerateImage sends an image URL to the vision model configured for the
// "image" moderation type. An empty imageURL is rejected up front; sending
// it would produce a request with no usable image.
func (s *OpenRouterService) ModerateImage(ctx context.Context, imageURL string) (*ModerationResult, error) {
	if imageURL == "" {
		return nil, fmt.Errorf("image moderation requires an image URL")
	}
	config, err := s.enabledModerationConfig(ctx, "image")
	if err != nil {
		return nil, err
	}
	return s.callModel(ctx, config.ModelID, config.SystemPrompt, "", []string{imageURL})
}

// ModerateVideo sends video frame URLs to the vision model configured for
// the "video" moderation type, with a text note telling the model how many
// frames it is looking at.
//
// At least one frame is required. Without images callModel would send a
// text-only request, and the model could "pass" a video it never saw.
func (s *OpenRouterService) ModerateVideo(ctx context.Context, frameURLs []string) (*ModerationResult, error) {
	if len(frameURLs) == 0 {
		return nil, fmt.Errorf("video moderation requires at least one frame URL")
	}
	config, err := s.enabledModerationConfig(ctx, "video")
	if err != nil {
		return nil, err
	}
	// Report the real frame count; the wording matches the previous fixed
	// "3 frames" prompt when exactly three frames are supplied.
	prompt := fmt.Sprintf("These are %d frames extracted from a short video. Analyze all frames for policy violations.", len(frameURLs))
	return s.callModel(ctx, config.ModelID, config.SystemPrompt, prompt, frameURLs)
}

// enabledModerationConfig loads the config for moderationType and checks
// that it is usable (enabled, with a model ID).
//
// Every failure is reported as "<type> moderation not configured". When the
// lookup itself failed, the underlying error is wrapped with %w. A database
// outage is then distinguishable from a missing row (errors.Is still works
// on the cause) instead of being silently masked.
func (s *OpenRouterService) enabledModerationConfig(ctx context.Context, moderationType string) (*ModerationConfigEntry, error) {
	config, err := s.GetModerationConfig(ctx, moderationType)
	if err != nil {
		return nil, fmt.Errorf("%s moderation not configured: %w", moderationType, err)
	}
	if !config.Enabled || config.ModelID == "" {
		return nil, fmt.Errorf("%s moderation not configured", moderationType)
	}
	return config, nil
}
// ModerationResult is the parsed response from OpenRouter moderation, as
// produced by parseModerationResponse. Hate, Greed and Delusion are the
// category scores the default system prompt asks for on a 0.0-1.0 scale.
// They are copied from the model's JSON as-is, and stay 0 when the reply
// could not be decoded as JSON.
type ModerationResult struct {
	Flagged    bool    `json:"flagged"`     // model's verdict, or the keyword-fallback result
	Reason     string  `json:"reason"`      // model-supplied reason, or a generic message from the fallback
	Hate       float64 `json:"hate"`
	Greed      float64 `json:"greed"`
	Delusion   float64 `json:"delusion"`
	RawContent string  `json:"raw_content"` // the unmodified model reply
}
// callModel sends a chat completion request to OpenRouter and parses the
// model's reply into a ModerationResult.
//
// systemPrompt falls back to defaultModerationSystemPrompt when empty.
// textContent becomes the user message. When imageURLs is non-empty, the
// user message is a multimodal content array instead (see
// buildModerationMessages).
//
// It returns an error when:
//   - the API key is missing;
//   - the request cannot be built or sent;
//   - OpenRouter answers with a non-200 status (the error includes at most
//     4 KiB of the body);
//   - the response cannot be decoded;
//   - the model returns no choices or an empty message.
//
// An empty reply carries no verdict. It is reported as an error, with the
// finish_reason, rather than being parsed as a clean "not flagged" result.
func (s *OpenRouterService) callModel(ctx context.Context, modelID, systemPrompt, textContent string, imageURLs []string) (*ModerationResult, error) {
	const maxErrorBodyBytes = 4 << 10

	if s.apiKey == "" {
		return nil, fmt.Errorf("OpenRouter API key not configured")
	}

	jsonBody, err := json.Marshal(OpenRouterChatRequest{
		Model:    modelID,
		Messages: buildModerationMessages(systemPrompt, textContent, imageURLs),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, "https://openrouter.ai/api/v1/chat/completions", bytes.NewReader(jsonBody))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", "Bearer "+s.apiKey)
	req.Header.Set("HTTP-Referer", "https://sojorn.net")
	req.Header.Set("X-Title", "Sojorn Moderation")

	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("OpenRouter request failed: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: a read error only shortens the message. The limit keeps
		// a large error page out of the error string and logs.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		return nil, fmt.Errorf("OpenRouter API error %d: %s", resp.StatusCode, string(body))
	}

	var chatResp OpenRouterChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	if len(chatResp.Choices) == 0 {
		return nil, fmt.Errorf("no response from model")
	}

	choice := chatResp.Choices[0]
	raw := choice.Message.Content
	if strings.TrimSpace(raw) == "" {
		// For example, a generation cut short ("length") or filtered
		// ("content_filter"). Without this check it would parse as "not flagged".
		return nil, fmt.Errorf("empty response from model (finish_reason %q)", choice.FinishReason)
	}
	return parseModerationResponse(raw), nil
}

// buildModerationMessages assembles the system and user messages for a
// moderation request. An empty systemPrompt is replaced by
// defaultModerationSystemPrompt.
//
// Without images, the user message content is the plain textContent string.
// With images, it is a content-part array: an optional {"type": "text"}
// part (only when textContent is non-empty), followed by one
// {"type": "image_url"} part per URL, in order.
func buildModerationMessages(systemPrompt, textContent string, imageURLs []string) []OpenRouterChatMessage {
	if systemPrompt == "" {
		systemPrompt = defaultModerationSystemPrompt
	}
	system := OpenRouterChatMessage{Role: "system", Content: systemPrompt}

	if len(imageURLs) == 0 {
		return []OpenRouterChatMessage{system, {Role: "user", Content: textContent}}
	}

	parts := make([]map[string]any, 0, len(imageURLs)+1)
	if textContent != "" {
		parts = append(parts, map[string]any{"type": "text", "text": textContent})
	}
	for _, imageURL := range imageURLs {
		parts = append(parts, map[string]any{
			"type":      "image_url",
			"image_url": map[string]string{"url": imageURL},
		})
	}
	return []OpenRouterChatMessage{system, {Role: "user", Content: parts}}
}
// parseModerationResponse tries to extract structured moderation data from
// model output.
//
// It looks for a JSON object with the fields the default prompt requests:
// "flagged", "reason", "hate", "greed" and "delusion". Models often wrap that
// object in a Markdown code fence (with or without a language tag) or
// surround it with prose. The outermost {...} span is therefore extracted
// first from the fenced content and then, failing that, from the whole
// reply.
//
// If no candidate decodes, it falls back to a keyword scan of the raw text
// that errs on the side of flagging. RawContent always holds the unmodified
// reply.
func parseModerationResponse(raw string) *ModerationResult {
	result := &ModerationResult{RawContent: raw}

	candidates := []string{stripCodeFence(raw)}
	if candidates[0] != raw {
		candidates = append(candidates, raw)
	}

	for _, text := range candidates {
		obj, ok := extractJSONObject(text)
		if !ok {
			continue
		}
		// A fresh value per attempt, so fields from a partially decoded failure
		// cannot leak into a later successful decode.
		var parsed struct {
			Flagged  bool    `json:"flagged"`
			Reason   string  `json:"reason"`
			Hate     float64 `json:"hate"`
			Greed    float64 `json:"greed"`
			Delusion float64 `json:"delusion"`
		}
		if err := json.Unmarshal([]byte(obj), &parsed); err != nil {
			continue
		}
		result.Flagged = parsed.Flagged
		result.Reason = parsed.Reason
		result.Hate = parsed.Hate
		result.Greed = parsed.Greed
		result.Delusion = parsed.Delusion
		return result
	}

	// Fallback: nothing decoded as JSON. Violation-sounding wording counts as
	// a flag (fail closed).
	// NOTE(review): "flagged" also matches malformed JSON such as
	// `{"flagged": false,}`; decide whether that bias is wanted.
	lower := strings.ToLower(raw)
	if strings.Contains(lower, "violation") || strings.Contains(lower, "inappropriate") || strings.Contains(lower, "flagged") {
		result.Flagged = true
		result.Reason = "Content flagged by AI moderation"
	}
	return result
}

// stripCodeFence returns the text after the first Markdown code fence in s,
// up to the next fence or the end of s. An opening "```json" tag is
// preferred and dropped. s is returned unchanged when it contains no fence.
func stripCodeFence(s string) string {
	var body string
	switch {
	case strings.Contains(s, "```json"):
		_, body, _ = strings.Cut(s, "```json")
	case strings.Contains(s, "```"):
		_, body, _ = strings.Cut(s, "```")
	default:
		return s
	}
	// An unterminated fence keeps everything after the opener.
	body, _, _ = strings.Cut(body, "```")
	return body
}

// extractJSONObject returns s from its first '{' through its last '}'
// inclusive, or false when s has no such span.
func extractJSONObject(s string) (string, bool) {
	start := strings.IndexByte(s, '{')
	end := strings.LastIndexByte(s, '}')
	if start < 0 || end < start {
		return "", false
	}
	return s[start : end+1], true
}
// defaultModerationSystemPrompt is the system prompt callModel uses when a
// moderation config has an empty SystemPrompt. It asks the model for a bare
// JSON object whose keys match the fields parseModerationResponse decodes:
// flagged, reason, hate, greed and delusion.
//
// NOTE(review): "Flag if any score > 0.5" is only an instruction to the
// model. parseModerationResponse trusts the returned "flagged" value and
// does not re-check the scores itself.
const defaultModerationSystemPrompt = `You are a content moderation AI for Sojorn, a social media platform.
Analyze the provided content for policy violations.
Respond ONLY with a JSON object in this exact format:
{
"flagged": true/false,
"reason": "brief reason if flagged, empty string if not",
"hate": 0.0-1.0,
"greed": 0.0-1.0,
"delusion": 0.0-1.0
}
Scoring guide (Three Poisons framework):
- hate: harassment, threats, violence, sexual content, hate speech, discrimination
- greed: spam, scams, crypto schemes, misleading promotions, get-rich-quick
- delusion: misinformation, self-harm content, conspiracy theories, dangerous medical advice
Score 0.0 = no concern, 1.0 = extreme violation. Flag if any score > 0.5.
Only respond with the JSON, no other text.`