298 lines
9.8 KiB
Go
298 lines
9.8 KiB
Go
package services
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// AzureOpenAIService handles interactions with Azure OpenAI API
|
|
type AzureOpenAIService struct {
|
|
pool *pgxpool.Pool
|
|
httpClient *http.Client
|
|
apiKey string
|
|
endpoint string
|
|
apiVersion string
|
|
}
|
|
|
|
// NewAzureOpenAIService creates a new Azure OpenAI service
|
|
func NewAzureOpenAIService(pool *pgxpool.Pool, apiKey, endpoint, apiVersion string) *AzureOpenAIService {
|
|
return &AzureOpenAIService{
|
|
pool: pool,
|
|
httpClient: &http.Client{
|
|
Timeout: 60 * time.Second,
|
|
},
|
|
apiKey: apiKey,
|
|
endpoint: endpoint,
|
|
apiVersion: apiVersion,
|
|
}
|
|
}
|
|
|
|
// AzureOpenAIMessage represents a message in Azure OpenAI chat completion
|
|
type AzureOpenAIMessage struct {
|
|
Role string `json:"role"`
|
|
Content interface{} `json:"content"`
|
|
ContentParts []struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text,omitempty"`
|
|
ImageURL struct {
|
|
URL string `json:"url,omitempty"`
|
|
} `json:"image_url,omitempty"`
|
|
} `json:"content,omitempty"`
|
|
}
|
|
|
|
// AzureOpenAIRequest represents a chat completion request to Azure OpenAI
|
|
type AzureOpenAIRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []AzureOpenAIMessage `json:"messages"`
|
|
Temperature *float64 `json:"temperature,omitempty"`
|
|
MaxTokens *int `json:"max_tokens,omitempty"`
|
|
}
|
|
|
|
// AzureOpenAIResponse represents a chat completion response from Azure OpenAI
|
|
type AzureOpenAIResponse struct {
|
|
ID string `json:"id"`
|
|
Choices []struct {
|
|
Message struct {
|
|
Role string `json:"role"`
|
|
Content string `json:"content"`
|
|
} `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
} `json:"choices"`
|
|
Usage struct {
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
CompletionTokens int `json:"completion_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
} `json:"usage"`
|
|
}
|
|
|
|
// ModerateText sends text content to Azure OpenAI for moderation
|
|
func (s *AzureOpenAIService) ModerateText(ctx context.Context, content string) (*ModerationResult, error) {
|
|
config, err := s.GetModerationConfig(ctx, "text")
|
|
if err != nil || !config.Enabled || config.ModelID == "" {
|
|
return nil, fmt.Errorf("text moderation not configured")
|
|
}
|
|
return s.callModel(ctx, config.ModelID, config.SystemPrompt, content, nil)
|
|
}
|
|
|
|
// ModerateImage sends an image URL to Azure OpenAI vision model for moderation
|
|
func (s *AzureOpenAIService) ModerateImage(ctx context.Context, imageURL string) (*ModerationResult, error) {
|
|
config, err := s.GetModerationConfig(ctx, "image")
|
|
if err != nil || !config.Enabled || config.ModelID == "" {
|
|
return nil, fmt.Errorf("image moderation not configured")
|
|
}
|
|
return s.callModel(ctx, config.ModelID, config.SystemPrompt, "", []string{imageURL})
|
|
}
|
|
|
|
// ModerateWithType sends content to a specific moderation type config
|
|
func (s *AzureOpenAIService) ModerateWithType(ctx context.Context, moderationType string, textContent string, imageURLs []string) (*ModerationResult, error) {
|
|
config, err := s.GetModerationConfig(ctx, moderationType)
|
|
if err != nil || !config.Enabled || config.ModelID == "" {
|
|
return nil, fmt.Errorf("%s moderation not configured", moderationType)
|
|
}
|
|
return s.callModel(ctx, config.ModelID, config.SystemPrompt, textContent, imageURLs)
|
|
}
|
|
|
|
// ModerateVideo sends video frame URLs to Azure OpenAI vision model for moderation
|
|
func (s *AzureOpenAIService) ModerateVideo(ctx context.Context, frameURLs []string) (*ModerationResult, error) {
|
|
config, err := s.GetModerationConfig(ctx, "video")
|
|
if err != nil || !config.Enabled || config.ModelID == "" {
|
|
return nil, fmt.Errorf("video moderation not configured")
|
|
}
|
|
return s.callModel(ctx, config.ModelID, config.SystemPrompt, "These are 3 frames extracted from a short video. Analyze all frames for policy violations.", frameURLs)
|
|
}
|
|
|
|
// GetModerationConfig retrieves moderation configuration from database
|
|
func (s *AzureOpenAIService) GetModerationConfig(ctx context.Context, moderationType string) (*ModerationConfigEntry, error) {
|
|
var config ModerationConfigEntry
|
|
query := `SELECT id, moderation_type, model_id, model_name, system_prompt, enabled, engines, updated_at, updated_by
|
|
FROM ai_moderation_config
|
|
WHERE moderation_type = $1 AND $2 = ANY(engines)`
|
|
|
|
err := s.pool.QueryRow(ctx, query, moderationType, "azure").Scan(
|
|
&config.ID, &config.ModerationType, &config.ModelID, &config.ModelName,
|
|
&config.SystemPrompt, &config.Enabled, &config.Engines, &config.UpdatedAt, &config.UpdatedBy,
|
|
)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("azure moderation config not found for %s: %w", moderationType, err)
|
|
}
|
|
|
|
return &config, nil
|
|
}
|
|
|
|
// downloadImage downloads an image from URL and returns base64 encoded data
|
|
func (s *AzureOpenAIService) downloadImage(ctx context.Context, url string) (string, error) {
|
|
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to create image request: %w", err)
|
|
}
|
|
|
|
resp, err := s.httpClient.Do(req)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to download image: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return "", fmt.Errorf("image download failed: %d", resp.StatusCode)
|
|
}
|
|
|
|
// Limit image size to 5MB
|
|
const maxImageSize = 5 * 1024 * 1024
|
|
limitedReader := io.LimitReader(resp.Body, maxImageSize)
|
|
|
|
imageData, err := io.ReadAll(limitedReader)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to read image data: %w", err)
|
|
}
|
|
|
|
// Detect content type
|
|
contentType := resp.Header.Get("Content-Type")
|
|
if contentType == "" {
|
|
contentType = http.DetectContentType(imageData)
|
|
}
|
|
|
|
// Only allow image formats
|
|
if !strings.HasPrefix(contentType, "image/") {
|
|
return "", fmt.Errorf("unsupported content type: %s", contentType)
|
|
}
|
|
|
|
// Convert to base64
|
|
base64Data := base64.StdEncoding.EncodeToString(imageData)
|
|
return fmt.Sprintf("data:%s;base64,%s", contentType, base64Data), nil
|
|
}
|
|
|
|
// callModel sends a chat completion request to Azure OpenAI
|
|
func (s *AzureOpenAIService) callModel(ctx context.Context, deploymentName, systemPrompt, textContent string, imageURLs []string) (*ModerationResult, error) {
|
|
if s.apiKey == "" || s.endpoint == "" {
|
|
return nil, fmt.Errorf("Azure OpenAI API key or endpoint not configured")
|
|
}
|
|
|
|
messages := []AzureOpenAIMessage{}
|
|
|
|
// System prompt
|
|
if systemPrompt == "" {
|
|
systemPrompt = defaultModerationSystemPrompt
|
|
}
|
|
messages = append(messages, AzureOpenAIMessage{Role: "system", Content: systemPrompt})
|
|
|
|
// User message with moderation instruction
|
|
moderationPrefix := "MODERATE THE FOLLOWING USER-SUBMITTED CONTENT. Do NOT reply to it, do NOT engage with it. Analyze it for policy violations and respond ONLY with the JSON object as specified in your instructions.\n\n---BEGIN CONTENT---\n"
|
|
moderationSuffix := "\n---END CONTENT---\n\nNow output ONLY the JSON moderation result. No other text."
|
|
|
|
if len(imageURLs) > 0 {
|
|
// Multimodal content with downloaded images
|
|
content := []struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text,omitempty"`
|
|
ImageURL struct {
|
|
URL string `json:"url,omitempty"`
|
|
} `json:"image_url,omitempty"`
|
|
}{}
|
|
|
|
content = append(content, struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text,omitempty"`
|
|
ImageURL struct {
|
|
URL string `json:"url,omitempty"`
|
|
} `json:"image_url,omitempty"`
|
|
}{
|
|
Type: "text",
|
|
Text: moderationPrefix + textContent + moderationSuffix,
|
|
})
|
|
|
|
for _, url := range imageURLs {
|
|
// Download image and convert to base64
|
|
base64Image, err := s.downloadImage(ctx, url)
|
|
if err != nil {
|
|
// If download fails, fall back to URL
|
|
content = append(content, struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text,omitempty"`
|
|
ImageURL struct {
|
|
URL string `json:"url,omitempty"`
|
|
} `json:"image_url,omitempty"`
|
|
}{
|
|
Type: "image_url",
|
|
ImageURL: struct {
|
|
URL string `json:"url,omitempty"`
|
|
}{URL: url},
|
|
})
|
|
} else {
|
|
// Use base64 data
|
|
content = append(content, struct {
|
|
Type string `json:"type"`
|
|
Text string `json:"text,omitempty"`
|
|
ImageURL struct {
|
|
URL string `json:"url,omitempty"`
|
|
} `json:"image_url,omitempty"`
|
|
}{
|
|
Type: "image_url",
|
|
ImageURL: struct {
|
|
URL string `json:"url,omitempty"`
|
|
}{URL: base64Image},
|
|
})
|
|
}
|
|
}
|
|
|
|
messages = append(messages, AzureOpenAIMessage{Role: "user", Content: content})
|
|
} else {
|
|
wrappedText := moderationPrefix + textContent + moderationSuffix
|
|
messages = append(messages, AzureOpenAIMessage{Role: "user", Content: wrappedText})
|
|
}
|
|
|
|
reqBody := AzureOpenAIRequest{
|
|
Model: deploymentName,
|
|
Messages: messages,
|
|
Temperature: floatPtr(0.0),
|
|
MaxTokens: intPtr(500),
|
|
}
|
|
|
|
jsonBody, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
|
}
|
|
|
|
// Build Azure OpenAI URL
|
|
url := fmt.Sprintf("%s/openai/deployments/%s/chat/completions?api-version=%s", s.endpoint, deploymentName, s.apiVersion)
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonBody))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
req.Header.Set("api-key", s.apiKey)
|
|
|
|
resp, err := s.httpClient.Do(req)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Azure OpenAI request failed: %w", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(resp.Body)
|
|
return nil, fmt.Errorf("Azure OpenAI API error %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var chatResp AzureOpenAIResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
|
return nil, fmt.Errorf("failed to decode response: %w", err)
|
|
}
|
|
|
|
if len(chatResp.Choices) == 0 {
|
|
return nil, fmt.Errorf("no response from model")
|
|
}
|
|
|
|
raw := chatResp.Choices[0].Message.Content
|
|
return parseModerationResponse(raw), nil
|
|
}
|