Ban enforcement: immediate session kill, IP logging, login/register/middleware checks
This commit is contained in:
parent
70fa1dddca
commit
f4701b0d24
|
|
@ -191,7 +191,7 @@ func main() {
|
|||
}
|
||||
|
||||
authorized := v1.Group("")
|
||||
authorized.Use(middleware.AuthMiddleware(cfg.JWTSecret))
|
||||
authorized.Use(middleware.AuthMiddleware(cfg.JWTSecret, dbPool))
|
||||
{
|
||||
authorized.GET("/profiles/:id", userHandler.GetProfile)
|
||||
authorized.GET("/profile", userHandler.GetProfile)
|
||||
|
|
@ -350,7 +350,7 @@ func main() {
|
|||
// Admin Panel API (requires auth + admin role)
|
||||
// ──────────────────────────────────────────────
|
||||
admin := r.Group("/api/v1/admin")
|
||||
admin.Use(middleware.AuthMiddleware(cfg.JWTSecret))
|
||||
admin.Use(middleware.AuthMiddleware(cfg.JWTSecret, dbPool))
|
||||
admin.Use(middleware.AdminMiddleware(dbPool))
|
||||
{
|
||||
// Dashboard
|
||||
|
|
|
|||
|
|
@ -75,6 +75,14 @@ func (h *AuthHandler) Register(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
// Check if this IP is banned (ban evasion prevention)
|
||||
ipBanned, _ := h.repo.IsIPBanned(c.Request.Context(), remoteIP)
|
||||
if ipBanned {
|
||||
log.Printf("[Auth] Registration blocked for banned IP: %s", remoteIP)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Registration is not available from this network."})
|
||||
return
|
||||
}
|
||||
|
||||
existingUser, err := h.repo.GetUserByEmail(c.Request.Context(), req.Email)
|
||||
if err == nil && existingUser != nil {
|
||||
c.JSON(http.StatusConflict, gin.H{"error": "Email already registered"})
|
||||
|
|
@ -178,6 +186,14 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
// Check if this IP is banned (ban evasion prevention)
|
||||
ipBanned, _ := h.repo.IsIPBanned(c.Request.Context(), remoteIP)
|
||||
if ipBanned {
|
||||
log.Printf("[Auth] Login blocked for banned IP: %s", remoteIP)
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Access is not available from this network."})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.repo.GetUserByEmail(c.Request.Context(), req.Email)
|
||||
if err != nil {
|
||||
log.Printf("[Auth] Login failed for %s: user not found", req.Email)
|
||||
|
|
@ -195,6 +211,14 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
|||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Email verification required", "code": "verify_email"})
|
||||
return
|
||||
}
|
||||
if user.Status == models.UserStatusBanned {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "This account has been permanently suspended for violating our community guidelines.", "code": "banned"})
|
||||
return
|
||||
}
|
||||
if user.Status == models.UserStatusSuspended {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Your account is temporarily suspended. Please try again later.", "code": "suspended"})
|
||||
return
|
||||
}
|
||||
if user.Status == models.UserStatusDeactivated {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Account deactivated"})
|
||||
return
|
||||
|
|
@ -359,6 +383,22 @@ func (h *AuthHandler) RefreshSession(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
// Check if user is banned/suspended before issuing new tokens
|
||||
rtUser, err := h.repo.GetUserByID(c.Request.Context(), rt.UserID.String())
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "User not found"})
|
||||
return
|
||||
}
|
||||
if rtUser.Status == models.UserStatusBanned {
|
||||
_ = h.repo.RevokeAllUserTokens(c.Request.Context(), rt.UserID.String())
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "This account has been permanently suspended.", "code": "banned"})
|
||||
return
|
||||
}
|
||||
if rtUser.Status == models.UserStatusSuspended {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Your account is temporarily suspended.", "code": "suspended"})
|
||||
return
|
||||
}
|
||||
|
||||
_ = h.repo.RevokeRefreshToken(c.Request.Context(), req.RefreshToken)
|
||||
|
||||
newAccessToken, err := h.generateToken(rt.UserID)
|
||||
|
|
|
|||
|
|
@ -61,8 +61,7 @@ func (h *PostHandler) CreateComment(c *gin.Context) {
|
|||
if h.contentFilter != nil {
|
||||
result := h.contentFilter.CheckContent(req.Body)
|
||||
if result.Blocked {
|
||||
// Record strike
|
||||
strikeCount, consequence, _ := h.contentFilter.RecordStrike(c.Request.Context(), userID, result.Category, req.Body)
|
||||
strikeCount, consequence, _ := h.contentFilter.RecordStrikeWithIP(c.Request.Context(), userID, result.Category, req.Body, c.ClientIP())
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{
|
||||
"error": result.Message,
|
||||
"blocked": true,
|
||||
|
|
@ -208,7 +207,7 @@ func (h *PostHandler) CreatePost(c *gin.Context) {
|
|||
if h.contentFilter != nil {
|
||||
result := h.contentFilter.CheckContent(req.Body)
|
||||
if result.Blocked {
|
||||
strikeCount, consequence, _ := h.contentFilter.RecordStrike(c.Request.Context(), userID, result.Category, req.Body)
|
||||
strikeCount, consequence, _ := h.contentFilter.RecordStrikeWithIP(c.Request.Context(), userID, result.Category, req.Body, c.ClientIP())
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{
|
||||
"error": result.Message,
|
||||
"blocked": true,
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
|
|
@ -37,7 +40,12 @@ func ParseToken(tokenString string, jwtSecret string) (string, jwt.MapClaims, er
|
|||
return userID, claims, nil
|
||||
}
|
||||
|
||||
func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
|
||||
func AuthMiddleware(jwtSecret string, pool ...*pgxpool.Pool) gin.HandlerFunc {
|
||||
var dbPool *pgxpool.Pool
|
||||
if len(pool) > 0 {
|
||||
dbPool = pool[0]
|
||||
}
|
||||
|
||||
return func(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
|
|
@ -63,6 +71,33 @@ func AuthMiddleware(jwtSecret string) gin.HandlerFunc {
|
|||
return
|
||||
}
|
||||
|
||||
// Check ban/suspend status from DB (immediate enforcement)
|
||||
if dbPool != nil {
|
||||
var status string
|
||||
var suspendedUntil *time.Time
|
||||
err := dbPool.QueryRow(context.Background(),
|
||||
`SELECT status, suspended_until FROM users WHERE id = $1::uuid`, userID,
|
||||
).Scan(&status, &suspendedUntil)
|
||||
if err == nil {
|
||||
if status == "banned" {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "This account has been permanently suspended.", "code": "banned"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if status == "suspended" {
|
||||
if suspendedUntil != nil && time.Now().After(*suspendedUntil) {
|
||||
// Suspension expired — reactivate
|
||||
dbPool.Exec(context.Background(),
|
||||
`UPDATE users SET status = 'active', suspended_until = NULL WHERE id = $1::uuid`, userID)
|
||||
} else {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Your account is temporarily suspended.", "code": "suspended"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store user ID and claims in context
|
||||
c.Set("user_id", userID)
|
||||
c.Set("claims", claims)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ const (
|
|||
UserStatusPending UserStatus = "pending"
|
||||
UserStatusActive UserStatus = "active"
|
||||
UserStatusDeactivated UserStatus = "deactivated"
|
||||
UserStatusBanned UserStatus = "banned"
|
||||
UserStatusSuspended UserStatus = "suspended"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
|
|
|
|||
|
|
@ -1302,3 +1302,21 @@ func (r *UserRepository) ExportUserData(ctx context.Context, userID string) (*Us
|
|||
|
||||
return export, nil
|
||||
}
|
||||
|
||||
// BanIP records an IP address as banned (used when a user is banned to prevent evasion)
|
||||
func (r *UserRepository) BanIP(ctx context.Context, ipAddress string, userID string, reason string) error {
|
||||
_, err := r.pool.Exec(ctx, `
|
||||
INSERT INTO banned_ips (ip_address, user_id, reason, banned_at)
|
||||
VALUES ($1, $2::uuid, $3, NOW())
|
||||
`, ipAddress, userID, reason)
|
||||
return err
|
||||
}
|
||||
|
||||
// IsIPBanned checks if an IP address has been banned
|
||||
func (r *UserRepository) IsIPBanned(ctx context.Context, ipAddress string) (bool, error) {
|
||||
var exists bool
|
||||
err := r.pool.QueryRow(ctx, `
|
||||
SELECT EXISTS(SELECT 1 FROM banned_ips WHERE ip_address = $1)
|
||||
`, ipAddress).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,6 +133,11 @@ func (cf *ContentFilter) CheckContent(text string) *ContentCheckResult {
|
|||
// 5 strikes: 7-day suspension
|
||||
// 7+ strikes: permanent ban
|
||||
func (cf *ContentFilter) RecordStrike(ctx context.Context, userID uuid.UUID, category, content string) (int, string, error) {
|
||||
return cf.RecordStrikeWithIP(ctx, userID, category, content, "")
|
||||
}
|
||||
|
||||
// RecordStrikeWithIP records a strike and logs the IP address for ban evasion prevention.
|
||||
func (cf *ContentFilter) RecordStrikeWithIP(ctx context.Context, userID uuid.UUID, category, content, clientIP string) (int, string, error) {
|
||||
// Insert strike
|
||||
_, err := cf.pool.Exec(ctx, `
|
||||
INSERT INTO content_strikes (user_id, category, content_snippet, created_at)
|
||||
|
|
@ -158,7 +163,16 @@ func (cf *ContentFilter) RecordStrike(ctx context.Context, userID uuid.UUID, cat
|
|||
case count >= 7:
|
||||
consequence = "ban"
|
||||
cf.pool.Exec(ctx, `UPDATE users SET status = 'banned' WHERE id = $1`, userID)
|
||||
fmt.Printf("Content filter: user %s BANNED (%d strikes)\n", userID, count)
|
||||
// Revoke ALL refresh tokens immediately so the user is logged out
|
||||
cf.pool.Exec(ctx, `UPDATE refresh_tokens SET revoked = true WHERE user_id = $1`, userID)
|
||||
// Log IP for ban evasion prevention
|
||||
if clientIP != "" {
|
||||
cf.pool.Exec(ctx, `
|
||||
INSERT INTO banned_ips (ip_address, user_id, reason, banned_at)
|
||||
VALUES ($1, $2, $3, NOW())
|
||||
`, clientIP, userID, fmt.Sprintf("auto-ban: %d strikes in 30 days", count))
|
||||
}
|
||||
fmt.Printf("Content filter: user %s BANNED (%d strikes), IP %s logged\n", userID, count, clientIP)
|
||||
case count >= 5:
|
||||
consequence = "suspend_7d"
|
||||
suspendUntil := time.Now().Add(7 * 24 * time.Hour)
|
||||
|
|
|
|||
11
go-backend/scripts/create_banned_ips.sql
Normal file
11
go-backend/scripts/create_banned_ips.sql
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
-- Banned IPs table for ban evasion prevention
|
||||
CREATE TABLE IF NOT EXISTS banned_ips (
|
||||
id SERIAL PRIMARY KEY,
|
||||
ip_address TEXT NOT NULL,
|
||||
user_id UUID REFERENCES users(id),
|
||||
reason TEXT,
|
||||
banned_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_banned_ips_address ON banned_ips (ip_address);
|
||||
CREATE INDEX IF NOT EXISTS idx_banned_ips_user ON banned_ips (user_id);
|
||||
Loading…
Reference in a new issue