// Listing metadata: 367 lines, 12 KiB, Go.
package services
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// OpenRouterService handles interactions with the OpenRouter API
|
|
type OpenRouterService struct {
|
|
pool *pgxpool.Pool
|
|
httpClient *http.Client
|
|
apiKey string
|
|
|
|
// Cached model list
|
|
modelCache []OpenRouterModel
|
|
modelCacheMu sync.RWMutex
|
|
modelCacheTime time.Time
|
|
}
|
|
|
|
// OpenRouterModel represents a model available on OpenRouter
|
|
type OpenRouterModel struct {
|
|
ID string `json:"id"`
|
|
Name string `json:"name"`
|
|
Description string `json:"description,omitempty"`
|
|
Pricing OpenRouterPricing `json:"pricing"`
|
|
ContextLength int `json:"context_length"`
|
|
Architecture map[string]any `json:"architecture,omitempty"`
|
|
TopProvider map[string]any `json:"top_provider,omitempty"`
|
|
PerRequestLimits map[string]any `json:"per_request_limits,omitempty"`
|
|
}
|
|
|
|
// OpenRouterPricing is the "pricing" object nested in an OpenRouterModel.
// Every value is kept as the string the API sent; Image and Request are left
// out of encoded JSON when empty.
type OpenRouterPricing struct {
	Prompt     string `json:"prompt"`
	Completion string `json:"completion"`
	Image      string `json:"image,omitempty"`
	Request    string `json:"request,omitempty"`
}
|
|
|
|
// ModerationConfigEntry mirrors one row of the ai_moderation_config table:
// which model and system prompt to use for a given moderation type, and
// whether that moderation type is switched on.
//
// UpdatedBy is a pointer so that a NULL column scans to nil; it is omitted
// from encoded JSON in that case.
type ModerationConfigEntry struct {
	ID             string    `json:"id"`
	ModerationType string    `json:"moderation_type"`
	ModelID        string    `json:"model_id"`
	ModelName      string    `json:"model_name"`
	SystemPrompt   string    `json:"system_prompt"`
	Enabled        bool      `json:"enabled"`
	UpdatedAt      time.Time `json:"updated_at"`
	UpdatedBy      *string   `json:"updated_by,omitempty"`
}
|
|
|
|
// OpenRouterChatMessage is a single entry of a chat completion request's
// "messages" array.
//
// Content is untyped because callModel sends either a plain string or, for
// multimodal requests, a slice of typed parts (text and image_url).
type OpenRouterChatMessage struct {
	Role    string `json:"role"`
	Content any    `json:"content"`
}
|
|
|
|
// OpenRouterChatRequest represents a chat completion request
|
|
type OpenRouterChatRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []OpenRouterChatMessage `json:"messages"`
|
|
}
|
|
|
|
// OpenRouterChatResponse is the subset of a chat completion reply that this
// package decodes. callModel reads only Choices[0].Message.Content; ID,
// FinishReason and Usage are decoded but not consulted in this file.
type OpenRouterChatResponse struct {
	ID      string `json:"id"`
	Choices []struct {
		Message struct {
			Role    string `json:"role"`
			Content string `json:"content"`
		} `json:"message"`
		FinishReason string `json:"finish_reason"`
	} `json:"choices"`
	Usage struct {
		PromptTokens     int `json:"prompt_tokens"`
		CompletionTokens int `json:"completion_tokens"`
		TotalTokens      int `json:"total_tokens"`
	} `json:"usage"`
}
|
|
|
|
func NewOpenRouterService(pool *pgxpool.Pool, apiKey string) *OpenRouterService {
|
|
return &OpenRouterService{
|
|
pool: pool,
|
|
apiKey: apiKey,
|
|
httpClient: &http.Client{
|
|
Timeout: 60 * time.Second,
|
|
},
|
|
}
|
|
}
|
|
|
|
// ListModels fetches available models from OpenRouter, with 1-hour cache
|
|
func (s *OpenRouterService) ListModels(ctx context.Context) ([]OpenRouterModel, error) {
|
|
s.modelCacheMu.RLock()
|
|
if len(s.modelCache) > 0 && time.Since(s.modelCacheTime) < time.Hour {
|
|
cached := s.modelCache
|
|
s.modelCacheMu.RUnlock()
|
|
return cached, nil
|
|
}
|
|
s.modelCacheMu.RUnlock()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://openrouter.ai/api/v1/models", nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
if s.apiKey != "" {
|
|
req.Header.Set("Authorization", "Bearer "+s.apiKey)
|
|
}
|
|
|
|
resp, err := s.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to fetch models: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("OpenRouter API error %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var result struct {
|
|
Data []OpenRouterModel `json:"data"`
|
|
}
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
return nil, fmt.Errorf("failed to decode models: %w", err)
|
|
}
|
|
|
|
s.modelCacheMu.Lock()
|
|
s.modelCache = result.Data
|
|
s.modelCacheTime = time.Now()
|
|
s.modelCacheMu.Unlock()
|
|
|
|
return result.Data, nil
|
|
}
|
|
|
|
// GetModerationConfigs returns all moderation type configurations
|
|
func (s *OpenRouterService) GetModerationConfigs(ctx context.Context) ([]ModerationConfigEntry, error) {
|
|
rows, err := s.pool.Query(ctx, `
|
|
SELECT id, moderation_type, model_id, model_name, system_prompt, enabled, updated_at, updated_by
|
|
FROM ai_moderation_config
|
|
ORDER BY moderation_type
|
|
`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to query configs: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var configs []ModerationConfigEntry
|
|
for rows.Next() {
|
|
var c ModerationConfigEntry
|
|
if err := rows.Scan(&c.ID, &c.ModerationType, &c.ModelID, &c.ModelName, &c.SystemPrompt, &c.Enabled, &c.UpdatedAt, &c.UpdatedBy); err != nil {
|
|
return nil, err
|
|
}
|
|
configs = append(configs, c)
|
|
}
|
|
return configs, nil
|
|
}
|
|
|
|
// GetModerationConfig returns config for a specific moderation type
|
|
func (s *OpenRouterService) GetModerationConfig(ctx context.Context, moderationType string) (*ModerationConfigEntry, error) {
|
|
var c ModerationConfigEntry
|
|
err := s.pool.QueryRow(ctx, `
|
|
SELECT id, moderation_type, model_id, model_name, system_prompt, enabled, updated_at, updated_by
|
|
FROM ai_moderation_config WHERE moderation_type = $1
|
|
`, moderationType).Scan(&c.ID, &c.ModerationType, &c.ModelID, &c.ModelName, &c.SystemPrompt, &c.Enabled, &c.UpdatedAt, &c.UpdatedBy)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &c, nil
|
|
}
|
|
|
|
// SetModerationConfig upserts a moderation config
|
|
func (s *OpenRouterService) SetModerationConfig(ctx context.Context, moderationType, modelID, modelName, systemPrompt string, enabled bool, updatedBy string) error {
|
|
_, err := s.pool.Exec(ctx, `
|
|
INSERT INTO ai_moderation_config (moderation_type, model_id, model_name, system_prompt, enabled, updated_by, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, NOW())
|
|
ON CONFLICT (moderation_type)
|
|
DO UPDATE SET model_id = $2, model_name = $3, system_prompt = $4, enabled = $5, updated_by = $6, updated_at = NOW()
|
|
`, moderationType, modelID, modelName, systemPrompt, enabled, updatedBy)
|
|
return err
|
|
}
|
|
|
|
// ModerateText sends text content to the configured model for moderation
|
|
func (s *OpenRouterService) ModerateText(ctx context.Context, content string) (*ModerationResult, error) {
|
|
config, err := s.GetModerationConfig(ctx, "text")
|
|
if err != nil || !config.Enabled || config.ModelID == "" {
|
|
return nil, fmt.Errorf("text moderation not configured")
|
|
}
|
|
return s.callModel(ctx, config.ModelID, config.SystemPrompt, content, nil)
|
|
}
|
|
|
|
// ModerateImage sends an image URL to a vision model for moderation
|
|
func (s *OpenRouterService) ModerateImage(ctx context.Context, imageURL string) (*ModerationResult, error) {
|
|
config, err := s.GetModerationConfig(ctx, "image")
|
|
if err != nil || !config.Enabled || config.ModelID == "" {
|
|
return nil, fmt.Errorf("image moderation not configured")
|
|
}
|
|
return s.callModel(ctx, config.ModelID, config.SystemPrompt, "", []string{imageURL})
|
|
}
|
|
|
|
// ModerateVideo sends video frame URLs to a vision model for moderation
|
|
func (s *OpenRouterService) ModerateVideo(ctx context.Context, frameURLs []string) (*ModerationResult, error) {
|
|
config, err := s.GetModerationConfig(ctx, "video")
|
|
if err != nil || !config.Enabled || config.ModelID == "" {
|
|
return nil, fmt.Errorf("video moderation not configured")
|
|
}
|
|
return s.callModel(ctx, config.ModelID, config.SystemPrompt, "These are 3 frames extracted from a short video. Analyze all frames for policy violations.", frameURLs)
|
|
}
|
|
|
|
// ModerationResult is what the moderation helpers hand back to callers after a
// model reply has been parsed.
//
// Hate, Greed and Delusion carry whatever numbers the model's JSON reply
// contained and stay zero when the reply could not be parsed as JSON.
// RawContent always holds the model's reply text exactly as received.
type ModerationResult struct {
	Flagged    bool    `json:"flagged"`
	Reason     string  `json:"reason"`
	Hate       float64 `json:"hate"`
	Greed      float64 `json:"greed"`
	Delusion   float64 `json:"delusion"`
	RawContent string  `json:"raw_content"`
}
|
|
|
|
// callModel sends a chat completion request to OpenRouter
|
|
func (s *OpenRouterService) callModel(ctx context.Context, modelID, systemPrompt, textContent string, imageURLs []string) (*ModerationResult, error) {
|
|
if s.apiKey == "" {
|
|
return nil, fmt.Errorf("OpenRouter API key not configured")
|
|
}
|
|
|
|
messages := []OpenRouterChatMessage{}
|
|
|
|
// System prompt
|
|
if systemPrompt == "" {
|
|
systemPrompt = defaultModerationSystemPrompt
|
|
}
|
|
messages = append(messages, OpenRouterChatMessage{Role: "system", Content: systemPrompt})
|
|
|
|
// User message — text only or multimodal (text + images)
|
|
if len(imageURLs) > 0 {
|
|
// Multimodal content array
|
|
parts := []map[string]any{}
|
|
if textContent != "" {
|
|
parts = append(parts, map[string]any{"type": "text", "text": textContent})
|
|
}
|
|
for _, url := range imageURLs {
|
|
parts = append(parts, map[string]any{
|
|
"type": "image_url",
|
|
"image_url": map[string]string{"url": url},
|
|
})
|
|
}
|
|
messages = append(messages, OpenRouterChatMessage{Role: "user", Content: parts})
|
|
} else {
|
|
messages = append(messages, OpenRouterChatMessage{Role: "user", Content: textContent})
|
|
}
|
|
|
|
reqBody := OpenRouterChatRequest{
|
|
Model: modelID,
|
|
Messages: messages,
|
|
}
|
|
|
|
jsonBody, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", "https://openrouter.ai/api/v1/chat/completions", bytes.NewBuffer(jsonBody))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("Authorization", "Bearer "+s.apiKey)
|
|
req.Header.Set("HTTP-Referer", "https://sojorn.net")
|
|
req.Header.Set("X-Title", "Sojorn Moderation")
|
|
|
|
resp, err := s.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("OpenRouter request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("OpenRouter API error %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var chatResp OpenRouterChatResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
if len(chatResp.Choices) == 0 {
|
|
return nil, fmt.Errorf("no response from model")
|
|
}
|
|
|
|
raw := chatResp.Choices[0].Message.Content
|
|
return parseModerationResponse(raw), nil
|
|
}
|
|
|
|
// parseModerationResponse tries to extract structured moderation data from model output
|
|
func parseModerationResponse(raw string) *ModerationResult {
|
|
result := &ModerationResult{RawContent: raw}
|
|
|
|
// Try to parse JSON from the response
|
|
// Models may wrap JSON in markdown code blocks
|
|
cleaned := raw
|
|
if idx := strings.Index(cleaned, "```json"); idx >= 0 {
|
|
cleaned = cleaned[idx+7:]
|
|
if end := strings.Index(cleaned, "```"); end >= 0 {
|
|
cleaned = cleaned[:end]
|
|
}
|
|
} else if idx := strings.Index(cleaned, "```"); idx >= 0 {
|
|
cleaned = cleaned[idx+3:]
|
|
if end := strings.Index(cleaned, "```"); end >= 0 {
|
|
cleaned = cleaned[:end]
|
|
}
|
|
}
|
|
cleaned = strings.TrimSpace(cleaned)
|
|
|
|
var parsed struct {
|
|
Flagged bool `json:"flagged"`
|
|
Reason string `json:"reason"`
|
|
Hate float64 `json:"hate"`
|
|
Greed float64 `json:"greed"`
|
|
Delusion float64 `json:"delusion"`
|
|
}
|
|
if err := json.Unmarshal([]byte(cleaned), &parsed); err == nil {
|
|
result.Flagged = parsed.Flagged
|
|
result.Reason = parsed.Reason
|
|
result.Hate = parsed.Hate
|
|
result.Greed = parsed.Greed
|
|
result.Delusion = parsed.Delusion
|
|
return result
|
|
}
|
|
|
|
// Fallback: check for keywords in raw text
|
|
lower := strings.ToLower(raw)
|
|
if strings.Contains(lower, "violation") || strings.Contains(lower, "inappropriate") || strings.Contains(lower, "flagged") {
|
|
result.Flagged = true
|
|
result.Reason = "Content flagged by AI moderation"
|
|
}
|
|
|
|
return result
|
|
}
|
|
|
|
// defaultModerationSystemPrompt is the system message callModel sends when it
// is given an empty system prompt. It asks the model to answer with a single
// JSON object whose keys (flagged, reason, hate, greed, delusion) are the ones
// parseModerationResponse decodes.
const defaultModerationSystemPrompt = `You are a content moderation AI for Sojorn, a social media platform.
Analyze the provided content for policy violations.

Respond ONLY with a JSON object in this exact format:
{
"flagged": true/false,
"reason": "brief reason if flagged, empty string if not",
"hate": 0.0-1.0,
"greed": 0.0-1.0,
"delusion": 0.0-1.0
}

Scoring guide (Three Poisons framework):
- hate: harassment, threats, violence, sexual content, hate speech, discrimination
- greed: spam, scams, crypto schemes, misleading promotions, get-rich-quick
- delusion: misinformation, self-harm content, conspiracy theories, dangerous medical advice

Score 0.0 = no concern, 1.0 = extreme violation. Flag if any score > 0.5.
Only respond with the JSON, no other text.`
|