Ban enforcement: immediate session kill, IP logging, login/register/middleware checks

This commit is contained in:
Patrick Britton 2026-02-06 12:09:02 -06:00
parent 70fa1dddca
commit f4701b0d24
8 changed files with 126 additions and 7 deletions

View file

@ -191,7 +191,7 @@ func main() {
} }
authorized := v1.Group("") authorized := v1.Group("")
authorized.Use(middleware.AuthMiddleware(cfg.JWTSecret)) authorized.Use(middleware.AuthMiddleware(cfg.JWTSecret, dbPool))
{ {
authorized.GET("/profiles/:id", userHandler.GetProfile) authorized.GET("/profiles/:id", userHandler.GetProfile)
authorized.GET("/profile", userHandler.GetProfile) authorized.GET("/profile", userHandler.GetProfile)
@ -350,7 +350,7 @@ func main() {
// Admin Panel API (requires auth + admin role) // Admin Panel API (requires auth + admin role)
// ────────────────────────────────────────────── // ──────────────────────────────────────────────
admin := r.Group("/api/v1/admin") admin := r.Group("/api/v1/admin")
admin.Use(middleware.AuthMiddleware(cfg.JWTSecret)) admin.Use(middleware.AuthMiddleware(cfg.JWTSecret, dbPool))
admin.Use(middleware.AdminMiddleware(dbPool)) admin.Use(middleware.AdminMiddleware(dbPool))
{ {
// Dashboard // Dashboard

View file

@ -75,6 +75,14 @@ func (h *AuthHandler) Register(c *gin.Context) {
return return
} }
// Check if this IP is banned (ban evasion prevention)
ipBanned, _ := h.repo.IsIPBanned(c.Request.Context(), remoteIP)
if ipBanned {
log.Printf("[Auth] Registration blocked for banned IP: %s", remoteIP)
c.JSON(http.StatusForbidden, gin.H{"error": "Registration is not available from this network."})
return
}
existingUser, err := h.repo.GetUserByEmail(c.Request.Context(), req.Email) existingUser, err := h.repo.GetUserByEmail(c.Request.Context(), req.Email)
if err == nil && existingUser != nil { if err == nil && existingUser != nil {
c.JSON(http.StatusConflict, gin.H{"error": "Email already registered"}) c.JSON(http.StatusConflict, gin.H{"error": "Email already registered"})
@ -178,6 +186,14 @@ func (h *AuthHandler) Login(c *gin.Context) {
return return
} }
// Check if this IP is banned (ban evasion prevention)
ipBanned, _ := h.repo.IsIPBanned(c.Request.Context(), remoteIP)
if ipBanned {
log.Printf("[Auth] Login blocked for banned IP: %s", remoteIP)
c.JSON(http.StatusForbidden, gin.H{"error": "Access is not available from this network."})
return
}
user, err := h.repo.GetUserByEmail(c.Request.Context(), req.Email) user, err := h.repo.GetUserByEmail(c.Request.Context(), req.Email)
if err != nil { if err != nil {
log.Printf("[Auth] Login failed for %s: user not found", req.Email) log.Printf("[Auth] Login failed for %s: user not found", req.Email)
@ -195,6 +211,14 @@ func (h *AuthHandler) Login(c *gin.Context) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Email verification required", "code": "verify_email"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Email verification required", "code": "verify_email"})
return return
} }
if user.Status == models.UserStatusBanned {
c.JSON(http.StatusForbidden, gin.H{"error": "This account has been permanently suspended for violating our community guidelines.", "code": "banned"})
return
}
if user.Status == models.UserStatusSuspended {
c.JSON(http.StatusForbidden, gin.H{"error": "Your account is temporarily suspended. Please try again later.", "code": "suspended"})
return
}
if user.Status == models.UserStatusDeactivated { if user.Status == models.UserStatusDeactivated {
c.JSON(http.StatusForbidden, gin.H{"error": "Account deactivated"}) c.JSON(http.StatusForbidden, gin.H{"error": "Account deactivated"})
return return
@ -359,6 +383,22 @@ func (h *AuthHandler) RefreshSession(c *gin.Context) {
return return
} }
// Check if user is banned/suspended before issuing new tokens
rtUser, err := h.repo.GetUserByID(c.Request.Context(), rt.UserID.String())
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
return
}
if rtUser.Status == models.UserStatusBanned {
_ = h.repo.RevokeAllUserTokens(c.Request.Context(), rt.UserID.String())
c.JSON(http.StatusForbidden, gin.H{"error": "This account has been permanently suspended.", "code": "banned"})
return
}
if rtUser.Status == models.UserStatusSuspended {
c.JSON(http.StatusForbidden, gin.H{"error": "Your account is temporarily suspended.", "code": "suspended"})
return
}
_ = h.repo.RevokeRefreshToken(c.Request.Context(), req.RefreshToken) _ = h.repo.RevokeRefreshToken(c.Request.Context(), req.RefreshToken)
newAccessToken, err := h.generateToken(rt.UserID) newAccessToken, err := h.generateToken(rt.UserID)

View file

@ -61,8 +61,7 @@ func (h *PostHandler) CreateComment(c *gin.Context) {
if h.contentFilter != nil { if h.contentFilter != nil {
result := h.contentFilter.CheckContent(req.Body) result := h.contentFilter.CheckContent(req.Body)
if result.Blocked { if result.Blocked {
// Record strike strikeCount, consequence, _ := h.contentFilter.RecordStrikeWithIP(c.Request.Context(), userID, result.Category, req.Body, c.ClientIP())
strikeCount, consequence, _ := h.contentFilter.RecordStrike(c.Request.Context(), userID, result.Category, req.Body)
c.JSON(http.StatusUnprocessableEntity, gin.H{ c.JSON(http.StatusUnprocessableEntity, gin.H{
"error": result.Message, "error": result.Message,
"blocked": true, "blocked": true,
@ -208,7 +207,7 @@ func (h *PostHandler) CreatePost(c *gin.Context) {
if h.contentFilter != nil { if h.contentFilter != nil {
result := h.contentFilter.CheckContent(req.Body) result := h.contentFilter.CheckContent(req.Body)
if result.Blocked { if result.Blocked {
strikeCount, consequence, _ := h.contentFilter.RecordStrike(c.Request.Context(), userID, result.Category, req.Body) strikeCount, consequence, _ := h.contentFilter.RecordStrikeWithIP(c.Request.Context(), userID, result.Category, req.Body, c.ClientIP())
c.JSON(http.StatusUnprocessableEntity, gin.H{ c.JSON(http.StatusUnprocessableEntity, gin.H{
"error": result.Message, "error": result.Message,
"blocked": true, "blocked": true,

View file

@ -1,12 +1,15 @@
package middleware package middleware
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
"time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@ -37,7 +40,12 @@ func ParseToken(tokenString string, jwtSecret string) (string, jwt.MapClaims, er
return userID, claims, nil return userID, claims, nil
} }
func AuthMiddleware(jwtSecret string) gin.HandlerFunc { func AuthMiddleware(jwtSecret string, pool ...*pgxpool.Pool) gin.HandlerFunc {
var dbPool *pgxpool.Pool
if len(pool) > 0 {
dbPool = pool[0]
}
return func(c *gin.Context) { return func(c *gin.Context) {
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
if authHeader == "" { if authHeader == "" {
@ -63,6 +71,33 @@ func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
return return
} }
// Check ban/suspend status from DB (immediate enforcement)
if dbPool != nil {
var status string
var suspendedUntil *time.Time
err := dbPool.QueryRow(context.Background(),
`SELECT status, suspended_until FROM users WHERE id = $1::uuid`, userID,
).Scan(&status, &suspendedUntil)
if err == nil {
if status == "banned" {
c.JSON(http.StatusForbidden, gin.H{"error": "This account has been permanently suspended.", "code": "banned"})
c.Abort()
return
}
if status == "suspended" {
if suspendedUntil != nil && time.Now().After(*suspendedUntil) {
// Suspension expired — reactivate
dbPool.Exec(context.Background(),
`UPDATE users SET status = 'active', suspended_until = NULL WHERE id = $1::uuid`, userID)
} else {
c.JSON(http.StatusForbidden, gin.H{"error": "Your account is temporarily suspended.", "code": "suspended"})
c.Abort()
return
}
}
}
}
// Store user ID and claims in context // Store user ID and claims in context
c.Set("user_id", userID) c.Set("user_id", userID)
c.Set("claims", claims) c.Set("claims", claims)

View file

@ -12,6 +12,8 @@ const (
UserStatusPending UserStatus = "pending" UserStatusPending UserStatus = "pending"
UserStatusActive UserStatus = "active" UserStatusActive UserStatus = "active"
UserStatusDeactivated UserStatus = "deactivated" UserStatusDeactivated UserStatus = "deactivated"
UserStatusBanned UserStatus = "banned"
UserStatusSuspended UserStatus = "suspended"
) )
type User struct { type User struct {

View file

@ -1302,3 +1302,21 @@ func (r *UserRepository) ExportUserData(ctx context.Context, userID string) (*Us
return export, nil return export, nil
} }
// BanIP records an IP address as banned (used when a user is banned to prevent evasion)
func (r *UserRepository) BanIP(ctx context.Context, ipAddress string, userID string, reason string) error {
_, err := r.pool.Exec(ctx, `
INSERT INTO banned_ips (ip_address, user_id, reason, banned_at)
VALUES ($1, $2::uuid, $3, NOW())
`, ipAddress, userID, reason)
return err
}
// IsIPBanned checks if an IP address has been banned
func (r *UserRepository) IsIPBanned(ctx context.Context, ipAddress string) (bool, error) {
var exists bool
err := r.pool.QueryRow(ctx, `
SELECT EXISTS(SELECT 1 FROM banned_ips WHERE ip_address = $1)
`, ipAddress).Scan(&exists)
return exists, err
}

View file

@ -133,6 +133,11 @@ func (cf *ContentFilter) CheckContent(text string) *ContentCheckResult {
// 5 strikes: 7-day suspension // 5 strikes: 7-day suspension
// 7+ strikes: permanent ban // 7+ strikes: permanent ban
func (cf *ContentFilter) RecordStrike(ctx context.Context, userID uuid.UUID, category, content string) (int, string, error) { func (cf *ContentFilter) RecordStrike(ctx context.Context, userID uuid.UUID, category, content string) (int, string, error) {
return cf.RecordStrikeWithIP(ctx, userID, category, content, "")
}
// RecordStrikeWithIP records a strike and logs the IP address for ban evasion prevention.
func (cf *ContentFilter) RecordStrikeWithIP(ctx context.Context, userID uuid.UUID, category, content, clientIP string) (int, string, error) {
// Insert strike // Insert strike
_, err := cf.pool.Exec(ctx, ` _, err := cf.pool.Exec(ctx, `
INSERT INTO content_strikes (user_id, category, content_snippet, created_at) INSERT INTO content_strikes (user_id, category, content_snippet, created_at)
@ -158,7 +163,16 @@ func (cf *ContentFilter) RecordStrike(ctx context.Context, userID uuid.UUID, cat
case count >= 7: case count >= 7:
consequence = "ban" consequence = "ban"
cf.pool.Exec(ctx, `UPDATE users SET status = 'banned' WHERE id = $1`, userID) cf.pool.Exec(ctx, `UPDATE users SET status = 'banned' WHERE id = $1`, userID)
fmt.Printf("Content filter: user %s BANNED (%d strikes)\n", userID, count) // Revoke ALL refresh tokens immediately so the user is logged out
cf.pool.Exec(ctx, `UPDATE refresh_tokens SET revoked = true WHERE user_id = $1`, userID)
// Log IP for ban evasion prevention
if clientIP != "" {
cf.pool.Exec(ctx, `
INSERT INTO banned_ips (ip_address, user_id, reason, banned_at)
VALUES ($1, $2, $3, NOW())
`, clientIP, userID, fmt.Sprintf("auto-ban: %d strikes in 30 days", count))
}
fmt.Printf("Content filter: user %s BANNED (%d strikes), IP %s logged\n", userID, count, clientIP)
case count >= 5: case count >= 5:
consequence = "suspend_7d" consequence = "suspend_7d"
suspendUntil := time.Now().Add(7 * 24 * time.Hour) suspendUntil := time.Now().Add(7 * 24 * time.Hour)

View file

@ -0,0 +1,11 @@
-- Banned IPs table for ban evasion prevention
CREATE TABLE IF NOT EXISTS banned_ips (
id SERIAL PRIMARY KEY,
ip_address TEXT NOT NULL,
user_id UUID REFERENCES users(id),
reason TEXT,
banned_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_banned_ips_address ON banned_ips (ip_address);
CREATE INDEX IF NOT EXISTS idx_banned_ips_user ON banned_ips (user_id);