fix: close ALL NSFW leaks - GetPostByID, GetPostChain, GetPostFocusContext now filter NSFW server-side
This commit is contained in:
parent
25d3e213ea
commit
8d419ba057
|
|
@ -610,7 +610,13 @@ func (h *PostHandler) GetPost(c *gin.Context) {
|
|||
postID := c.Param("id")
|
||||
userIDStr, _ := c.Get("user_id")
|
||||
|
||||
post, err := h.postRepo.GetPostByID(c.Request.Context(), postID, userIDStr.(string))
|
||||
// Check viewer's NSFW preference
|
||||
showNSFW := false
|
||||
if settings, err := h.userRepo.GetUserSettings(c.Request.Context(), userIDStr.(string)); err == nil && settings.NSFWEnabled != nil {
|
||||
showNSFW = *settings.NSFWEnabled
|
||||
}
|
||||
|
||||
post, err := h.postRepo.GetPostByID(c.Request.Context(), postID, userIDStr.(string), showNSFW)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Post not found"})
|
||||
return
|
||||
|
|
@ -897,7 +903,15 @@ func (h *PostHandler) GetLikedPosts(c *gin.Context) {
|
|||
func (h *PostHandler) GetPostChain(c *gin.Context) {
|
||||
postID := c.Param("id")
|
||||
|
||||
posts, err := h.postRepo.GetPostChain(c.Request.Context(), postID)
|
||||
// Check viewer's NSFW preference
|
||||
showNSFW := false
|
||||
if viewerID, exists := c.Get("user_id"); exists {
|
||||
if settings, err := h.userRepo.GetUserSettings(c.Request.Context(), viewerID.(string)); err == nil && settings.NSFWEnabled != nil {
|
||||
showNSFW = *settings.NSFWEnabled
|
||||
}
|
||||
}
|
||||
|
||||
posts, err := h.postRepo.GetPostChain(c.Request.Context(), postID, showNSFW)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch post chain", "details": err.Error()})
|
||||
return
|
||||
|
|
@ -926,7 +940,13 @@ func (h *PostHandler) GetPostFocusContext(c *gin.Context) {
|
|||
postID := c.Param("id")
|
||||
userIDStr, _ := c.Get("user_id")
|
||||
|
||||
focusContext, err := h.postRepo.GetPostFocusContext(c.Request.Context(), postID, userIDStr.(string))
|
||||
// Check viewer's NSFW preference
|
||||
showNSFW := false
|
||||
if settings, err := h.userRepo.GetUserSettings(c.Request.Context(), userIDStr.(string)); err == nil && settings.NSFWEnabled != nil {
|
||||
showNSFW = *settings.NSFWEnabled
|
||||
}
|
||||
|
||||
focusContext, err := h.postRepo.GetPostFocusContext(c.Request.Context(), postID, userIDStr.(string), showNSFW)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch focus context", "details": err.Error()})
|
||||
return
|
||||
|
|
|
|||
|
|
@ -304,8 +304,12 @@ func (r *PostRepository) GetPostsByAuthor(ctx context.Context, authorID string,
|
|||
return posts, nil
|
||||
}
|
||||
|
||||
func (r *PostRepository) GetPostByID(ctx context.Context, postID string, userID string) (*models.Post, error) {
|
||||
func (r *PostRepository) GetPostByID(ctx context.Context, postID string, userID string, showNSFW ...bool) (*models.Post, error) {
|
||||
log.Error().Str("postID", postID).Str("userID", userID).Msg("TEST: GetPostByID called")
|
||||
filterNSFW := true
|
||||
if len(showNSFW) > 0 && showNSFW[0] {
|
||||
filterNSFW = false
|
||||
}
|
||||
query := `
|
||||
SELECT
|
||||
p.id,
|
||||
|
|
@ -342,9 +346,10 @@ func (r *PostRepository) GetPostByID(ctx context.Context, postID string, userID
|
|||
)
|
||||
)
|
||||
AND NOT public.has_block_between(p.author_id, CASE WHEN $2 != '' THEN $2::uuid ELSE NULL END)
|
||||
AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $3 = FALSE)
|
||||
`
|
||||
var p models.Post
|
||||
err := r.pool.QueryRow(ctx, query, postID, userID).Scan(
|
||||
err := r.pool.QueryRow(ctx, query, postID, userID, filterNSFW).Scan(
|
||||
&p.ID, &p.AuthorID, &p.CategoryID, &p.Body, &p.ImageURL, &p.VideoURL, &p.ThumbnailURL, &p.DurationMS, &p.Tags, &p.CreatedAt,
|
||||
&p.ChainParentID,
|
||||
&p.AuthorHandle, &p.AuthorDisplayName, &p.AuthorAvatarURL,
|
||||
|
|
@ -659,7 +664,7 @@ func (r *PostRepository) GetLikedPosts(ctx context.Context, userID string, limit
|
|||
return posts, nil
|
||||
}
|
||||
|
||||
func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]models.Post, error) {
|
||||
func (r *PostRepository) GetPostChain(ctx context.Context, rootID string, showNSFW bool) ([]models.Post, error) {
|
||||
// Recursive CTE to get the chain
|
||||
query := `
|
||||
WITH RECURSIVE object_chain AS (
|
||||
|
|
@ -680,6 +685,7 @@ func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]mod
|
|||
JOIN public.profiles pr ON p.author_id = pr.id
|
||||
LEFT JOIN public.post_metrics m ON p.id = m.post_id
|
||||
WHERE p.id = $1::uuid AND p.deleted_at IS NULL
|
||||
AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $2 = TRUE)
|
||||
|
||||
UNION ALL
|
||||
|
||||
|
|
@ -701,6 +707,7 @@ func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]mod
|
|||
LEFT JOIN public.post_metrics m ON p.id = m.post_id
|
||||
JOIN object_chain oc ON p.chain_parent_id = oc.id
|
||||
WHERE p.deleted_at IS NULL
|
||||
AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $2 = TRUE)
|
||||
),
|
||||
comments_chain AS (
|
||||
SELECT
|
||||
|
|
@ -740,7 +747,7 @@ func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]mod
|
|||
FROM comments_chain
|
||||
ORDER BY level ASC, created_at ASC;
|
||||
`
|
||||
rows, err := r.pool.Query(ctx, query, rootID)
|
||||
rows, err := r.pool.Query(ctx, query, rootID, showNSFW)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -978,11 +985,11 @@ func (r *PostRepository) RemoveBeaconVote(ctx context.Context, beaconID string,
|
|||
|
||||
// GetPostFocusContext retrieves minimal data for Focus-Context view
|
||||
// Returns: Target Post, Direct Parent (if any), and Direct Children (1st layer only)
|
||||
func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string, userID string) (*models.FocusContext, error) {
|
||||
func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string, userID string, showNSFW bool) (*models.FocusContext, error) {
|
||||
log.Info().Str("postID", postID).Str("userID", userID).Msg("DEBUG: GetPostFocusContext called")
|
||||
|
||||
// Get target post
|
||||
targetPost, err := r.GetPostByID(ctx, postID, userID)
|
||||
targetPost, err := r.GetPostByID(ctx, postID, userID, showNSFW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get target post: %w", err)
|
||||
}
|
||||
|
|
@ -993,7 +1000,7 @@ func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string,
|
|||
|
||||
// Get parent post if chain_parent_id exists
|
||||
if targetPost.ChainParentID != nil {
|
||||
parentPost, err = r.GetPostByID(ctx, targetPost.ChainParentID.String(), userID)
|
||||
parentPost, err = r.GetPostByID(ctx, targetPost.ChainParentID.String(), userID, showNSFW)
|
||||
if err != nil {
|
||||
// Parent might not exist or be inaccessible - continue without it
|
||||
parentPost = nil
|
||||
|
|
@ -1032,10 +1039,11 @@ func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string,
|
|||
WHERE f.follower_id = CASE WHEN $2 != '' THEN $2::uuid ELSE NULL END AND f.following_id = p.author_id AND f.status = 'accepted'
|
||||
)
|
||||
)
|
||||
AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $3 = TRUE)
|
||||
ORDER BY p.created_at ASC
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, childrenQuery, postID, userID)
|
||||
rows, err := r.pool.Query(ctx, childrenQuery, postID, userID, showNSFW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get children posts: %w", err)
|
||||
}
|
||||
|
|
@ -1076,7 +1084,7 @@ func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string,
|
|||
|
||||
// If we have a parent, fetch its direct children (siblings + current)
|
||||
if parentPost != nil {
|
||||
siblingRows, err := r.pool.Query(ctx, childrenQuery, parentPost.ID.String(), userID)
|
||||
siblingRows, err := r.pool.Query(ctx, childrenQuery, parentPost.ID.String(), userID, showNSFW)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get parent children: %w", err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue