fix: close ALL NSFW leaks - GetPostByID, GetPostChain, GetPostFocusContext now filter NSFW server-side

This commit is contained in:
Patrick Britton 2026-02-08 00:27:23 -06:00
parent 25d3e213ea
commit 8d419ba057
2 changed files with 40 additions and 12 deletions

View file

@ -610,7 +610,13 @@ func (h *PostHandler) GetPost(c *gin.Context) {
postID := c.Param("id")
userIDStr, _ := c.Get("user_id")
post, err := h.postRepo.GetPostByID(c.Request.Context(), postID, userIDStr.(string))
// Check viewer's NSFW preference
showNSFW := false
if settings, err := h.userRepo.GetUserSettings(c.Request.Context(), userIDStr.(string)); err == nil && settings.NSFWEnabled != nil {
showNSFW = *settings.NSFWEnabled
}
post, err := h.postRepo.GetPostByID(c.Request.Context(), postID, userIDStr.(string), showNSFW)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Post not found"})
return
@ -897,7 +903,15 @@ func (h *PostHandler) GetLikedPosts(c *gin.Context) {
func (h *PostHandler) GetPostChain(c *gin.Context) {
postID := c.Param("id")
posts, err := h.postRepo.GetPostChain(c.Request.Context(), postID)
// Check viewer's NSFW preference
showNSFW := false
if viewerID, exists := c.Get("user_id"); exists {
if settings, err := h.userRepo.GetUserSettings(c.Request.Context(), viewerID.(string)); err == nil && settings.NSFWEnabled != nil {
showNSFW = *settings.NSFWEnabled
}
}
posts, err := h.postRepo.GetPostChain(c.Request.Context(), postID, showNSFW)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch post chain", "details": err.Error()})
return
@ -926,7 +940,13 @@ func (h *PostHandler) GetPostFocusContext(c *gin.Context) {
postID := c.Param("id")
userIDStr, _ := c.Get("user_id")
focusContext, err := h.postRepo.GetPostFocusContext(c.Request.Context(), postID, userIDStr.(string))
// Check viewer's NSFW preference
showNSFW := false
if settings, err := h.userRepo.GetUserSettings(c.Request.Context(), userIDStr.(string)); err == nil && settings.NSFWEnabled != nil {
showNSFW = *settings.NSFWEnabled
}
focusContext, err := h.postRepo.GetPostFocusContext(c.Request.Context(), postID, userIDStr.(string), showNSFW)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch focus context", "details": err.Error()})
return

View file

@ -304,8 +304,12 @@ func (r *PostRepository) GetPostsByAuthor(ctx context.Context, authorID string,
return posts, nil
}
func (r *PostRepository) GetPostByID(ctx context.Context, postID string, userID string) (*models.Post, error) {
func (r *PostRepository) GetPostByID(ctx context.Context, postID string, userID string, showNSFW ...bool) (*models.Post, error) {
log.Error().Str("postID", postID).Str("userID", userID).Msg("TEST: GetPostByID called")
filterNSFW := true
if len(showNSFW) > 0 && showNSFW[0] {
filterNSFW = false
}
query := `
SELECT
p.id,
@ -342,9 +346,10 @@ func (r *PostRepository) GetPostByID(ctx context.Context, postID string, userID
)
)
AND NOT public.has_block_between(p.author_id, CASE WHEN $2 != '' THEN $2::uuid ELSE NULL END)
AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $3 = FALSE)
`
var p models.Post
err := r.pool.QueryRow(ctx, query, postID, userID).Scan(
err := r.pool.QueryRow(ctx, query, postID, userID, filterNSFW).Scan(
&p.ID, &p.AuthorID, &p.CategoryID, &p.Body, &p.ImageURL, &p.VideoURL, &p.ThumbnailURL, &p.DurationMS, &p.Tags, &p.CreatedAt,
&p.ChainParentID,
&p.AuthorHandle, &p.AuthorDisplayName, &p.AuthorAvatarURL,
@ -659,7 +664,7 @@ func (r *PostRepository) GetLikedPosts(ctx context.Context, userID string, limit
return posts, nil
}
func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]models.Post, error) {
func (r *PostRepository) GetPostChain(ctx context.Context, rootID string, showNSFW bool) ([]models.Post, error) {
// Recursive CTE to get the chain
query := `
WITH RECURSIVE object_chain AS (
@ -680,6 +685,7 @@ func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]mod
JOIN public.profiles pr ON p.author_id = pr.id
LEFT JOIN public.post_metrics m ON p.id = m.post_id
WHERE p.id = $1::uuid AND p.deleted_at IS NULL
AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $2 = TRUE)
UNION ALL
@ -701,6 +707,7 @@ func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]mod
LEFT JOIN public.post_metrics m ON p.id = m.post_id
JOIN object_chain oc ON p.chain_parent_id = oc.id
WHERE p.deleted_at IS NULL
AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $2 = TRUE)
),
comments_chain AS (
SELECT
@ -740,7 +747,7 @@ func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]mod
FROM comments_chain
ORDER BY level ASC, created_at ASC;
`
rows, err := r.pool.Query(ctx, query, rootID)
rows, err := r.pool.Query(ctx, query, rootID, showNSFW)
if err != nil {
return nil, err
}
@ -978,11 +985,11 @@ func (r *PostRepository) RemoveBeaconVote(ctx context.Context, beaconID string,
// GetPostFocusContext retrieves minimal data for Focus-Context view
// Returns: Target Post, Direct Parent (if any), and Direct Children (1st layer only)
func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string, userID string) (*models.FocusContext, error) {
func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string, userID string, showNSFW bool) (*models.FocusContext, error) {
log.Info().Str("postID", postID).Str("userID", userID).Msg("DEBUG: GetPostFocusContext called")
// Get target post
targetPost, err := r.GetPostByID(ctx, postID, userID)
targetPost, err := r.GetPostByID(ctx, postID, userID, showNSFW)
if err != nil {
return nil, fmt.Errorf("failed to get target post: %w", err)
}
@ -993,7 +1000,7 @@ func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string,
// Get parent post if chain_parent_id exists
if targetPost.ChainParentID != nil {
parentPost, err = r.GetPostByID(ctx, targetPost.ChainParentID.String(), userID)
parentPost, err = r.GetPostByID(ctx, targetPost.ChainParentID.String(), userID, showNSFW)
if err != nil {
// Parent might not exist or be inaccessible - continue without it
parentPost = nil
@ -1032,10 +1039,11 @@ func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string,
WHERE f.follower_id = CASE WHEN $2 != '' THEN $2::uuid ELSE NULL END AND f.following_id = p.author_id AND f.status = 'accepted'
)
)
AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $3 = TRUE)
ORDER BY p.created_at ASC
`
rows, err := r.pool.Query(ctx, childrenQuery, postID, userID)
rows, err := r.pool.Query(ctx, childrenQuery, postID, userID, showNSFW)
if err != nil {
return nil, fmt.Errorf("failed to get children posts: %w", err)
}
@ -1076,7 +1084,7 @@ func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string,
// If we have a parent, fetch its direct children (siblings + current)
if parentPost != nil {
siblingRows, err := r.pool.Query(ctx, childrenQuery, parentPost.ID.String(), userID)
siblingRows, err := r.pool.Query(ctx, childrenQuery, parentPost.ID.String(), userID, showNSFW)
if err != nil {
return nil, fmt.Errorf("failed to get parent children: %w", err)
}