// sojorn/go-backend/internal/services/link_preview_service.go
// (323 lines, 8.1 KiB, Go)

package services
import (
	"context"
	"fmt"
	"html"
	"io"
	"net"
	"net/http"
	"net/url"
	"regexp"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/rs/zerolog/log"
)
// LinkPreview represents the OG metadata extracted from a URL.
// The JSON tags mirror the link_preview_* columns on public.posts (see
// EnrichPostsWithLinkPreviews / SaveLinkPreview), so the struct can be
// embedded directly in API responses.
type LinkPreview struct {
	URL         string `json:"link_preview_url"`         // the page the preview was fetched from
	Title       string `json:"link_preview_title"`       // og:title, or <title> fallback
	Description string `json:"link_preview_description"` // og:description, or meta "description" fallback
	ImageURL    string `json:"link_preview_image_url"`   // og:image, resolved to an absolute URL
	SiteName    string `json:"link_preview_site_name"`   // og:site_name, or hostname fallback
}
// LinkPreviewService fetches and parses OpenGraph metadata from URLs.
type LinkPreviewService struct {
	pool       *pgxpool.Pool // posts storage for persisting/reading previews
	httpClient *http.Client  // shared client with timeout + redirect cap (see NewLinkPreviewService)
}
// NewLinkPreviewService builds a LinkPreviewService backed by the given pool.
// The shared HTTP client enforces an 8-second overall timeout and refuses to
// follow more than 5 redirects.
//
// NOTE(review): redirect targets are not re-validated against the blocked IP
// ranges, so an untrusted URL could pass validateURL and then redirect to an
// internal address — confirm whether CheckRedirect should re-run the
// private-IP check on req.URL.
func NewLinkPreviewService(pool *pgxpool.Pool) *LinkPreviewService {
	client := &http.Client{
		Timeout: 8 * time.Second,
		CheckRedirect: func(req *http.Request, via []*http.Request) error {
			if len(via) >= 5 {
				return fmt.Errorf("too many redirects")
			}
			return nil
		},
	}
	return &LinkPreviewService{
		pool:       pool,
		httpClient: client,
	}
}
// blockedIPRanges are private/internal IP ranges that untrusted URLs must not resolve to.
var blockedIPRanges = []string{
"127.0.0.0/8",
"10.0.0.0/8",
"172.16.0.0/12",
"192.168.0.0/16",
"169.254.0.0/16",
"::1/128",
"fc00::/7",
"fe80::/10",
}
var blockedNets []*net.IPNet
func init() {
for _, cidr := range blockedIPRanges {
_, ipNet, err := net.ParseCIDR(cidr)
if err == nil {
blockedNets = append(blockedNets, ipNet)
}
}
}
func isPrivateIP(ip net.IP) bool {
for _, n := range blockedNets {
if n.Contains(ip) {
return true
}
}
return false
}
// firstURLRe matches the first http/https URL in free text. Compiled once at
// package scope instead of on every call (regexp compilation is expensive and
// this runs per post).
var firstURLRe = regexp.MustCompile(`https?://[^\s<>"')\]]+`)

// ExtractFirstURL finds the first http/https URL in a text string.
// Trailing sentence punctuation (".", ",", ";", ":", "!", "?") is stripped,
// since it is almost never part of the intended link. Returns "" when no URL
// is present.
func ExtractFirstURL(text string) string {
	match := firstURLRe.FindString(text)
	// Clean trailing punctuation that's not part of the URL.
	return strings.TrimRight(match, ".,;:!?")
}
// FetchPreview fetches OpenGraph metadata from rawURL.
//
// When trusted is false the URL is vetted first: only http/https schemes are
// accepted and the host must not be, or resolve to, a private/internal IP
// (see validateURL). The response must be 200 with an HTML content type, and
// at most 1 MiB of the body is read.
//
// Returns an error when the URL is empty, invalid, or unsafe; when the fetch
// fails or returns a non-200/non-HTML response; or when the page yields no OG
// title, description, or image at all.
//
// NOTE(review): validateURL resolves DNS separately from the actual dial, so
// a rebinding attacker could pass the check and then resolve to an internal
// address when the client connects; a DialContext-level IP check would close
// that gap — confirm the threat model before relying on this for isolation.
func (s *LinkPreviewService) FetchPreview(ctx context.Context, rawURL string, trusted bool) (*LinkPreview, error) {
	if rawURL == "" {
		return nil, fmt.Errorf("empty URL")
	}
	parsed, err := url.Parse(rawURL)
	if err != nil {
		return nil, fmt.Errorf("invalid URL: %w", err)
	}
	if parsed.Scheme != "http" && parsed.Scheme != "https" {
		return nil, fmt.Errorf("unsupported scheme: %s", parsed.Scheme)
	}
	// Safety checks apply only to URLs we did not generate ourselves.
	if !trusted {
		if err := s.validateURL(parsed); err != nil {
			return nil, fmt.Errorf("unsafe URL: %w", err)
		}
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
	if err != nil {
		return nil, err
	}
	// Some sites gate OG tags behind a browser-like UA; the token also
	// identifies our crawler to site operators.
	req.Header.Set("User-Agent", "Mozilla/5.0 (compatible; Sojorn/1.0; +https://sojorn.net)")
	req.Header.Set("Accept", "text/html")
	resp, err := s.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("fetch failed: %w", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
	}
	ct := resp.Header.Get("Content-Type")
	if !strings.Contains(ct, "text/html") && !strings.Contains(ct, "application/xhtml") {
		return nil, fmt.Errorf("not HTML: %s", ct)
	}
	// Cap the read at 1 MiB: OG tags live in <head>, and this bounds memory
	// use against huge or malicious responses.
	limited := io.LimitReader(resp.Body, 1*1024*1024)
	body, err := io.ReadAll(limited)
	if err != nil {
		return nil, err
	}
	preview := s.parseOGTags(string(body), rawURL)
	if preview.Title == "" && preview.Description == "" && preview.ImageURL == "" {
		return nil, fmt.Errorf("no OG metadata found")
	}
	preview.URL = rawURL
	if preview.SiteName == "" {
		preview.SiteName = parsed.Hostname()
	}
	return preview, nil
}
// validateURL checks that an untrusted URL doesn't point to internal
// resources: the host must not be a blocked IP literal, and every address it
// resolves to must fall outside the blocked ranges.
//
// NOTE(review): this is a check-then-use DNS lookup; the HTTP client resolves
// the host again when dialing, which leaves a rebinding window between the
// two resolutions — confirm whether that residual risk is acceptable.
func (s *LinkPreviewService) validateURL(u *url.URL) error {
	host := u.Hostname()
	// Literal IPs need no DNS round trip: vet them directly and return.
	// (The original fell through to LookupIP, which just echoes a literal.)
	if ip := net.ParseIP(host); ip != nil {
		if isPrivateIP(ip) {
			return fmt.Errorf("private IP not allowed")
		}
		return nil
	}
	// Resolve DNS and check every returned address (A and AAAA).
	ips, err := net.LookupIP(host)
	if err != nil {
		return fmt.Errorf("DNS lookup failed: %w", err)
	}
	for _, ip := range ips {
		if isPrivateIP(ip) {
			return fmt.Errorf("resolves to private IP")
		}
	}
	return nil
}
// Meta/title regexes are compiled once at package scope rather than on every
// parse call.
var (
	metaTagRe = regexp.MustCompile(`(?i)<meta\s+[^>]*>`)
	// (?s) lets "." cross newlines: <title> content is often line-wrapped,
	// which the original single-line pattern silently missed.
	titleTagRe = regexp.MustCompile(`(?is)<title[^>]*>(.*?)</title>`)
)

// unescapeHTML decodes HTML entities (&amp;, &#39;, …) in extracted metadata.
// It lives outside parseOGTags because that function's `html` parameter
// shadows the html package name.
func unescapeHTML(s string) string { return html.UnescapeString(s) }

// truncateUTF8 caps s at max bytes without splitting a multi-byte rune — a
// plain s[:max] byte slice could cut mid-sequence and yield invalid UTF-8.
func truncateUTF8(s string, max int) string {
	if len(s) <= max {
		return s
	}
	cut := max
	for cut > 0 && !utf8.RuneStart(s[cut]) {
		cut--
	}
	return s[:cut]
}

// parseOGTags extracts OpenGraph meta tags from raw HTML using lightweight
// regexes — no parser dependency needed. It honors og:title, og:description,
// og:image, og:site_name, with plain meta "description" and <title> as
// fallbacks. The first occurrence of each tag wins; og:image values are
// resolved to absolute URLs against sourceURL; entity-encoded content is
// decoded before storage.
func (s *LinkPreviewService) parseOGTags(html string, sourceURL string) *LinkPreview {
	preview := &LinkPreview{}
	for _, tag := range metaTagRe.FindAllString(html, -1) {
		prop := extractAttr(tag, "property")
		if prop == "" {
			prop = extractAttr(tag, "name")
		}
		content := extractAttr(tag, "content")
		if content == "" {
			continue
		}
		// Sites entity-encode freely ("Tom &amp; Jerry", &amp; in image query
		// strings); decode so we store and render the real text.
		content = unescapeHTML(content)
		switch strings.ToLower(prop) {
		case "og:title":
			if preview.Title == "" {
				preview.Title = content
			}
		case "og:description":
			if preview.Description == "" {
				preview.Description = content
			}
		case "og:image":
			if preview.ImageURL == "" {
				preview.ImageURL = resolveImageURL(content, sourceURL)
			}
		case "og:site_name":
			if preview.SiteName == "" {
				preview.SiteName = content
			}
		case "description":
			// Fallback if no og:description is present.
			if preview.Description == "" {
				preview.Description = content
			}
		}
	}
	// Fallback: try the <title> tag if no og:title was found.
	if preview.Title == "" {
		if m := titleTagRe.FindStringSubmatch(html); len(m) > 1 {
			preview.Title = unescapeHTML(strings.TrimSpace(m[1]))
		}
	}
	// Truncate long fields, keeping rune boundaries intact.
	preview.Title = truncateUTF8(preview.Title, 300)
	preview.Description = truncateUTF8(preview.Description, 500)
	return preview
}
// extractAttr pulls a named attribute value from a raw HTML tag string.
// Matches name="value" or name='value' with properly paired quotes, so a
// value containing the other quote character (content='say "hi"') is
// returned intact. The original `["']...["']` pattern allowed mismatched
// quote pairs and could not represent either quote inside a value.
// Returns "" when the attribute is absent or unquoted.
func extractAttr(tag string, name string) string {
	re := regexp.MustCompile(`(?i)\b` + regexp.QuoteMeta(name) + `\s*=\s*(?:"([^"]*)"|'([^']*)')`)
	m := re.FindStringSubmatch(tag)
	if len(m) > 2 {
		// Exactly one of the two capture groups matched.
		if m[1] != "" {
			return strings.TrimSpace(m[1])
		}
		return strings.TrimSpace(m[2])
	}
	return ""
}
// resolveImageURL makes a relative image URL absolute against the page it was
// found on. Already-absolute http(s) URLs pass through untouched; if either
// URL fails to parse, the raw value is returned as a best effort.
func resolveImageURL(imgURL string, sourceURL string) string {
	for _, scheme := range []string{"http://", "https://"} {
		if strings.HasPrefix(imgURL, scheme) {
			return imgURL
		}
	}
	base, baseErr := url.Parse(sourceURL)
	if baseErr != nil {
		return imgURL
	}
	rel, relErr := url.Parse(imgURL)
	if relErr != nil {
		return imgURL
	}
	return base.ResolveReference(rel).String()
}
// EnrichPostsWithLinkPreviews does a batch query to populate link_preview
// fields on a slice of posts. This avoids modifying every existing SELECT
// query. Returns a map of post ID -> preview containing only posts that have
// a non-empty link_preview_url; a nil map (and nil error) when postIDs is
// empty.
func (s *LinkPreviewService) EnrichPostsWithLinkPreviews(ctx context.Context, postIDs []string) (map[string]*LinkPreview, error) {
	if len(postIDs) == 0 {
		return nil, nil
	}
	query := `
SELECT id::text, link_preview_url, link_preview_title,
link_preview_description, link_preview_image_url, link_preview_site_name
FROM public.posts
WHERE id = ANY($1::uuid[]) AND link_preview_url IS NOT NULL AND link_preview_url != ''
`
	rows, err := s.pool.Query(ctx, query, postIDs)
	if err != nil {
		log.Warn().Err(err).Msg("Failed to fetch link previews for posts")
		return nil, err
	}
	defer rows.Close()
	result := make(map[string]*LinkPreview)
	for rows.Next() {
		var postID string
		var lp LinkPreview
		// title/description/image/site_name columns are nullable.
		var title, desc, imgURL, siteName *string
		if err := rows.Scan(&postID, &lp.URL, &title, &desc, &imgURL, &siteName); err != nil {
			// Best-effort enrichment: skip the bad row, but leave a trace
			// instead of silently dropping it as the original did.
			log.Warn().Err(err).Msg("Failed to scan link preview row")
			continue
		}
		if title != nil {
			lp.Title = *title
		}
		if desc != nil {
			lp.Description = *desc
		}
		if imgURL != nil {
			lp.ImageURL = *imgURL
		}
		if siteName != nil {
			lp.SiteName = *siteName
		}
		result[postID] = &lp
	}
	// Surface errors that ended iteration early (e.g. a dropped connection);
	// without this check a partially-filled map would be returned as success.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return result, nil
}
// SaveLinkPreview stores the link preview data for a post.
// All five link_preview_* columns on public.posts are overwritten from lp;
// updating a nonexistent post is not an error (zero rows affected).
func (s *LinkPreviewService) SaveLinkPreview(ctx context.Context, postID string, lp *LinkPreview) error {
	const query = `
UPDATE public.posts
SET link_preview_url = $2, link_preview_title = $3, link_preview_description = $4,
link_preview_image_url = $5, link_preview_site_name = $6
WHERE id = $1
`
	_, err := s.pool.Exec(ctx, query, postID, lp.URL, lp.Title, lp.Description, lp.ImageURL, lp.SiteName)
	return err
}