diff --git a/go-backend/internal/handlers/post_handler.go b/go-backend/internal/handlers/post_handler.go index b5c1126..41ef4b0 100644 --- a/go-backend/internal/handlers/post_handler.go +++ b/go-backend/internal/handlers/post_handler.go @@ -610,7 +610,13 @@ func (h *PostHandler) GetPost(c *gin.Context) { postID := c.Param("id") userIDStr, _ := c.Get("user_id") - post, err := h.postRepo.GetPostByID(c.Request.Context(), postID, userIDStr.(string)) + // Check viewer's NSFW preference + showNSFW := false + if settings, err := h.userRepo.GetUserSettings(c.Request.Context(), userIDStr.(string)); err == nil && settings.NSFWEnabled != nil { + showNSFW = *settings.NSFWEnabled + } + + post, err := h.postRepo.GetPostByID(c.Request.Context(), postID, userIDStr.(string), showNSFW) if err != nil { c.JSON(http.StatusNotFound, gin.H{"error": "Post not found"}) return @@ -897,7 +903,15 @@ func (h *PostHandler) GetLikedPosts(c *gin.Context) { func (h *PostHandler) GetPostChain(c *gin.Context) { postID := c.Param("id") - posts, err := h.postRepo.GetPostChain(c.Request.Context(), postID) + // Check viewer's NSFW preference + showNSFW := false + if viewerID, exists := c.Get("user_id"); exists { + if settings, err := h.userRepo.GetUserSettings(c.Request.Context(), viewerID.(string)); err == nil && settings.NSFWEnabled != nil { + showNSFW = *settings.NSFWEnabled + } + } + + posts, err := h.postRepo.GetPostChain(c.Request.Context(), postID, showNSFW) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch post chain", "details": err.Error()}) return @@ -926,7 +940,13 @@ func (h *PostHandler) GetPostFocusContext(c *gin.Context) { postID := c.Param("id") userIDStr, _ := c.Get("user_id") - focusContext, err := h.postRepo.GetPostFocusContext(c.Request.Context(), postID, userIDStr.(string)) + // Check viewer's NSFW preference + showNSFW := false + if settings, err := h.userRepo.GetUserSettings(c.Request.Context(), userIDStr.(string)); err == nil && settings.NSFWEnabled != nil { + showNSFW = *settings.NSFWEnabled + } + + focusContext, err := h.postRepo.GetPostFocusContext(c.Request.Context(), postID, userIDStr.(string), showNSFW) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to fetch focus context", "details": err.Error()}) return diff --git a/go-backend/internal/repository/post_repository.go b/go-backend/internal/repository/post_repository.go index 4cf5fd0..87bacc5 100644 --- a/go-backend/internal/repository/post_repository.go +++ b/go-backend/internal/repository/post_repository.go @@ -304,8 +304,12 @@ func (r *PostRepository) GetPostsByAuthor(ctx context.Context, authorID string, return posts, nil } -func (r *PostRepository) GetPostByID(ctx context.Context, postID string, userID string) (*models.Post, error) { +func (r *PostRepository) GetPostByID(ctx context.Context, postID string, userID string, showNSFW ...bool) (*models.Post, error) { log.Error().Str("postID", postID).Str("userID", userID).Msg("TEST: GetPostByID called") + filterNSFW := true + if len(showNSFW) > 0 && showNSFW[0] { + filterNSFW = false + } query := ` SELECT p.id, @@ -342,9 +346,10 @@ func (r *PostRepository) GetPostByID(ctx context.Context, postID string, userID ) ) AND NOT public.has_block_between(p.author_id, CASE WHEN $2 != '' THEN $2::uuid ELSE NULL END) + AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $3 = FALSE) ` var p models.Post - err := r.pool.QueryRow(ctx, query, postID, userID).Scan( + err := r.pool.QueryRow(ctx, query, postID, userID, filterNSFW).Scan( &p.ID, &p.AuthorID, &p.CategoryID, &p.Body, &p.ImageURL, &p.VideoURL, &p.ThumbnailURL, &p.DurationMS, &p.Tags, &p.CreatedAt, &p.ChainParentID, &p.AuthorHandle, &p.AuthorDisplayName, &p.AuthorAvatarURL, @@ -659,7 +664,7 @@ func (r *PostRepository) GetLikedPosts(ctx context.Context, userID string, limit return posts, nil } -func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]models.Post, error) { +func (r *PostRepository) GetPostChain(ctx context.Context, rootID string, showNSFW bool) ([]models.Post, error) { // Recursive CTE to get the chain query := ` WITH RECURSIVE object_chain AS ( @@ -680,6 +685,7 @@ func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]mod JOIN public.profiles pr ON p.author_id = pr.id LEFT JOIN public.post_metrics m ON p.id = m.post_id WHERE p.id = $1::uuid AND p.deleted_at IS NULL + AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $2 = TRUE) UNION ALL @@ -701,6 +707,7 @@ func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]mod LEFT JOIN public.post_metrics m ON p.id = m.post_id JOIN object_chain oc ON p.chain_parent_id = oc.id WHERE p.deleted_at IS NULL + AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $2 = TRUE) ), comments_chain AS ( SELECT @@ -740,7 +747,7 @@ func (r *PostRepository) GetPostChain(ctx context.Context, rootID string) ([]mod FROM comments_chain ORDER BY level ASC, created_at ASC; ` - rows, err := r.pool.Query(ctx, query, rootID) + rows, err := r.pool.Query(ctx, query, rootID, showNSFW) if err != nil { return nil, err } @@ -978,11 +985,11 @@ func (r *PostRepository) RemoveBeaconVote(ctx context.Context, beaconID string, // GetPostFocusContext retrieves minimal data for Focus-Context view // Returns: Target Post, Direct Parent (if any), and Direct Children (1st layer only) -func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string, userID string) (*models.FocusContext, error) { +func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string, userID string, showNSFW bool) (*models.FocusContext, error) { log.Info().Str("postID", postID).Str("userID", userID).Msg("DEBUG: GetPostFocusContext called") // Get target post - targetPost, err := r.GetPostByID(ctx, postID, userID) + targetPost, err := r.GetPostByID(ctx, postID, userID, showNSFW) if err != nil { return nil, fmt.Errorf("failed to get target post: %w", err) } @@ -993,7 +1000,7 @@ func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string, // Get parent post if chain_parent_id exists if targetPost.ChainParentID != nil { - parentPost, err = r.GetPostByID(ctx, targetPost.ChainParentID.String(), userID) + parentPost, err = r.GetPostByID(ctx, targetPost.ChainParentID.String(), userID, showNSFW) if err != nil { // Parent might not exist or be inaccessible - continue without it parentPost = nil @@ -1032,10 +1039,11 @@ func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string, WHERE f.follower_id = CASE WHEN $2 != '' THEN $2::uuid ELSE NULL END AND f.following_id = p.author_id AND f.status = 'accepted' ) ) + AND (COALESCE(p.is_nsfw, FALSE) = FALSE OR $3 = TRUE) ORDER BY p.created_at ASC ` - rows, err := r.pool.Query(ctx, childrenQuery, postID, userID) + rows, err := r.pool.Query(ctx, childrenQuery, postID, userID, showNSFW) if err != nil { return nil, fmt.Errorf("failed to get children posts: %w", err) } @@ -1076,7 +1084,7 @@ func (r *PostRepository) GetPostFocusContext(ctx context.Context, postID string, // If we have a parent, fetch its direct children (siblings + current) if parentPost != nil { - siblingRows, err := r.pool.Query(ctx, childrenQuery, parentPost.ID.String(), userID) + siblingRows, err := r.pool.Query(ctx, childrenQuery, parentPost.ID.String(), userID, showNSFW) if err != nil { return nil, fmt.Errorf("failed to get parent children: %w", err) }