sojorn/go-backend/internal/repository/backup_repository.go
2026-02-15 00:33:24 -06:00

502 lines
16 KiB
Go

package repository
import (
	"context"
	"crypto/rand"
	"errors"
	"fmt"
	"math/big"
	"time"

	"github.com/google/uuid"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/patbritton/sojorn-backend/internal/models"
)
// BackupRepository groups the database operations for sync codes, cloud
// backups, social recovery, backup preferences and user devices.
type BackupRepository struct {
	// pool is the pgx connection pool that every query in this repository runs on.
	pool *pgxpool.Pool
}
// NewBackupRepository returns a BackupRepository that issues its queries
// through the given connection pool.
func NewBackupRepository(pool *pgxpool.Pool) *BackupRepository {
	repo := new(BackupRepository)
	repo.pool = pool
	return repo
}
// Sync Code Methods
// GenerateSyncCode creates a new 6-digit sync code for userID and stores it
// together with the supplied device name and fingerprint.
//
// The code expires 5 minutes after it is generated. A user may create at most
// 5 codes within the trailing hour; once that many rows exist the call fails
// with a "rate limit exceeded" error and nothing is inserted.
//
// The returned SyncCode is filled from the INSERT's RETURNING clause, which
// does not include user_id, so its UserID field is not populated here.
func (r *BackupRepository) GenerateSyncCode(ctx context.Context, userID uuid.UUID, deviceName, deviceFingerprint string) (*models.SyncCode, error) {
	const (
		countRecentSQL = `
SELECT COUNT(*) FROM sync_codes
WHERE user_id = $1 AND created_at > NOW() - INTERVAL '1 hour'
`
		insertCodeSQL = `
INSERT INTO sync_codes (user_id, code, device_fingerprint, device_name, expires_at)
VALUES ($1, $2, $3, $4, $5)
RETURNING id, code, device_fingerprint, device_name, expires_at, used_at, attempts, created_at
`
	)

	// The code and its deadline are computed up front, before any query runs.
	digits := generateRandomCode(6)
	deadline := time.Now().Add(5 * time.Minute)

	// Rate limit: count the codes this user created during the last hour.
	var issuedLastHour int
	if err := r.pool.QueryRow(ctx, countRecentSQL, userID).Scan(&issuedLastHour); err != nil {
		return nil, fmt.Errorf("failed to check rate limit: %w", err)
	}
	if issuedLastHour >= 5 {
		return nil, fmt.Errorf("rate limit exceeded: too many codes generated")
	}

	// Persist the code and read the stored row back.
	created := new(models.SyncCode)
	row := r.pool.QueryRow(ctx, insertCodeSQL, userID, digits, deviceFingerprint, deviceName, deadline)
	if err := row.Scan(
		&created.ID, &created.Code, &created.DeviceFingerprint, &created.DeviceName,
		&created.ExpiresAt, &created.UsedAt, &created.Attempts, &created.CreatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to generate sync code: %w", err)
	}
	return created, nil
}
// VerifySyncCode looks up an unused sync code and, when it is still valid,
// records one more verification attempt against it.
//
// The call fails when no unused row matches code ("invalid sync code"), when
// the row's expiry time has passed, or when 3 attempts have already been
// recorded. On success the stored row is returned as it was read, i.e. its
// Attempts field holds the value from before this call's increment.
//
// NOTE(review): userID, requestingUserID, deviceName and deviceFingerprint are
// accepted but not read by this method; the lookup matches on the code alone
// and used_at is never set here. Confirm that ownership checks and marking the
// code as used happen in the caller.
func (r *BackupRepository) VerifySyncCode(ctx context.Context, code string, userID, requestingUserID uuid.UUID, deviceName, deviceFingerprint string) (*models.SyncCode, error) {
	var syncCode models.SyncCode
	err := r.pool.QueryRow(ctx, `
SELECT id, user_id, code, device_fingerprint, device_name, expires_at, used_at, attempts, created_at
FROM sync_codes
WHERE code = $1 AND used_at IS NULL
`, code).Scan(
		&syncCode.ID, &syncCode.UserID, &syncCode.Code, &syncCode.DeviceFingerprint,
		&syncCode.DeviceName, &syncCode.ExpiresAt, &syncCode.UsedAt, &syncCode.Attempts, &syncCode.CreatedAt,
	)
	// errors.Is keeps the not-found branch working even if the sentinel is
	// ever returned wrapped.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, fmt.Errorf("invalid sync code")
	}
	if err != nil {
		return nil, fmt.Errorf("failed to verify sync code: %w", err)
	}

	// Reject codes whose lifetime is over.
	if time.Now().After(syncCode.ExpiresAt) {
		return nil, fmt.Errorf("sync code has expired")
	}

	// Reject codes that have already been tried 3 times.
	if syncCode.Attempts >= 3 {
		return nil, fmt.Errorf("too many attempts for this sync code")
	}

	// Record this attempt.
	_, err = r.pool.Exec(ctx, `
UPDATE sync_codes SET attempts = attempts + 1 WHERE id = $1
`, syncCode.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to increment attempts: %w", err)
	}
	return &syncCode, nil
}
// Backup Methods
// UploadBackup stores a new encrypted backup for backup.UserID and returns the
// ID of the inserted row.
//
// Retention: before inserting, all but the 2 most recent existing backups of
// the user are deleted, so that the user ends up with at most 3 backups once
// the new one is stored. (Previously the 3 newest rows were kept and a fourth
// was then inserted.) Pruning and insert run in one transaction, so a failed
// insert does not leave old backups deleted.
//
// A nil backup yields an error instead of a nil-pointer panic.
func (r *BackupRepository) UploadBackup(ctx context.Context, backup *models.CloudBackup) (uuid.UUID, error) {
	if backup == nil {
		return uuid.Nil, fmt.Errorf("failed to upload backup: backup is nil")
	}

	tx, err := r.pool.Begin(ctx)
	if err != nil {
		return uuid.Nil, fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Rolls back on every early return; after a successful Commit this only
	// reports that the transaction is already closed, which is ignored.
	defer tx.Rollback(ctx)

	// Make room for the new row: keep only the 2 most recent backups.
	_, err = tx.Exec(ctx, `
DELETE FROM cloud_backups
WHERE user_id = $1 AND id NOT IN (
SELECT id FROM cloud_backups
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT 2
)
`, backup.UserID)
	if err != nil {
		return uuid.Nil, fmt.Errorf("failed to clean old backups: %w", err)
	}

	// Insert new backup
	var backupID uuid.UUID
	err = tx.QueryRow(ctx, `
INSERT INTO cloud_backups (user_id, encrypted_blob, salt, nonce, mac, version, device_name, size_bytes)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id
`, backup.UserID, backup.EncryptedBlob, backup.Salt, backup.Nonce, backup.Mac,
		backup.Version, backup.DeviceName, backup.SizeBytes).Scan(&backupID)
	if err != nil {
		return uuid.Nil, fmt.Errorf("failed to upload backup: %w", err)
	}

	if err := tx.Commit(ctx); err != nil {
		return uuid.Nil, fmt.Errorf("failed to upload backup: %w", err)
	}
	return backupID, nil
}
// GetLatestBackup returns the most recently created backup of userID.
//
// When the user has no backups the error "no backup found" is returned; any
// other query or scan failure is wrapped and returned.
func (r *BackupRepository) GetLatestBackup(ctx context.Context, userID uuid.UUID) (*models.CloudBackup, error) {
	var backup models.CloudBackup
	err := r.pool.QueryRow(ctx, `
SELECT id, user_id, encrypted_blob, salt, nonce, mac, version, device_name, size_bytes, created_at
FROM cloud_backups
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT 1
`, userID).Scan(
		&backup.ID, &backup.UserID, &backup.EncryptedBlob, &backup.Salt, &backup.Nonce,
		&backup.Mac, &backup.Version, &backup.DeviceName, &backup.SizeBytes, &backup.CreatedAt,
	)
	// errors.Is keeps the not-found branch working even if the sentinel is
	// ever returned wrapped.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, fmt.Errorf("no backup found")
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get latest backup: %w", err)
	}
	return &backup, nil
}
// GetBackup returns the backup with the given ID, provided it belongs to
// userID.
//
// When no row matches both the backup ID and the user ID the error
// "backup not found" is returned; any other query or scan failure is wrapped
// and returned.
func (r *BackupRepository) GetBackup(ctx context.Context, backupID, userID uuid.UUID) (*models.CloudBackup, error) {
	var backup models.CloudBackup
	err := r.pool.QueryRow(ctx, `
SELECT id, user_id, encrypted_blob, salt, nonce, mac, version, device_name, size_bytes, created_at
FROM cloud_backups
WHERE id = $1 AND user_id = $2
`, backupID, userID).Scan(
		&backup.ID, &backup.UserID, &backup.EncryptedBlob, &backup.Salt, &backup.Nonce,
		&backup.Mac, &backup.Version, &backup.DeviceName, &backup.SizeBytes, &backup.CreatedAt,
	)
	// errors.Is keeps the not-found branch working even if the sentinel is
	// ever returned wrapped.
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, fmt.Errorf("backup not found")
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get backup: %w", err)
	}
	return &backup, nil
}
// ListBackups returns every backup of userID, newest first.
//
// The result is a nil slice when the user has no backups. Each returned
// element includes the encrypted blob, since the query selects it.
func (r *BackupRepository) ListBackups(ctx context.Context, userID uuid.UUID) ([]models.CloudBackup, error) {
	rows, err := r.pool.Query(ctx, `
SELECT id, user_id, encrypted_blob, salt, nonce, mac, version, device_name, size_bytes, created_at
FROM cloud_backups
WHERE user_id = $1
ORDER BY created_at DESC
`, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to list backups: %w", err)
	}
	defer rows.Close()

	var backups []models.CloudBackup
	for rows.Next() {
		var backup models.CloudBackup
		err := rows.Scan(
			&backup.ID, &backup.UserID, &backup.EncryptedBlob, &backup.Salt, &backup.Nonce,
			&backup.Mac, &backup.Version, &backup.DeviceName, &backup.SizeBytes, &backup.CreatedAt,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to scan backup: %w", err)
		}
		backups = append(backups, backup)
	}
	// Next returning false can mean "no more rows" or "iteration failed";
	// without this check a failure would be reported as a shorter list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to list backups: %w", err)
	}
	return backups, nil
}
// DeleteBackup removes the backup with the given ID if it belongs to userID.
//
// The number of affected rows is not inspected, so a backup that does not
// exist (or is owned by a different user) is not reported as an error.
func (r *BackupRepository) DeleteBackup(ctx context.Context, backupID, userID uuid.UUID) error {
	const deleteSQL = `
DELETE FROM cloud_backups WHERE id = $1 AND user_id = $2
`
	if _, err := r.pool.Exec(ctx, deleteSQL, backupID, userID); err != nil {
		return fmt.Errorf("failed to delete backup: %w", err)
	}
	return nil
}
// Social Recovery Methods
// SetupSocialRecovery replaces the recovery guardians of userID with the given
// list, inside a single transaction.
//
// All existing guardian rows of the user are deleted first, then one row per
// entry of guardianIDs is inserted, using the entry's position in the slice as
// shard_index. The commit error, if any, is returned as-is.
//
// Simplification carried over from the original: every guardian row receives
// the complete masterSecret as its shard; nothing is split or encrypted per
// guardian here.
func (r *BackupRepository) SetupSocialRecovery(ctx context.Context, userID uuid.UUID, guardianIDs []uuid.UUID, masterSecret []byte) error {
	const insertGuardianSQL = `
INSERT INTO recovery_guardians (user_id, guardian_user_id, shard_encrypted, shard_index)
VALUES ($1, $2, $3, $4)
`
	tx, err := r.pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("failed to begin transaction: %w", err)
	}
	// Undo everything on any early return below.
	defer tx.Rollback(ctx)

	// Start from a clean slate for this user.
	if _, err := tx.Exec(ctx, "DELETE FROM recovery_guardians WHERE user_id = $1", userID); err != nil {
		return fmt.Errorf("failed to clear existing guardians: %w", err)
	}

	// One row per guardian; the slice position doubles as the shard index.
	for shardIndex := range guardianIDs {
		if _, err := tx.Exec(ctx, insertGuardianSQL, userID, guardianIDs[shardIndex], masterSecret, shardIndex); err != nil {
			return fmt.Errorf("failed to insert guardian: %w", err)
		}
	}
	return tx.Commit(ctx)
}
// InitiateRecovery opens a new recovery session for userID using the given
// method and returns the stored session row.
//
// The session is created with shards_needed set to 3 and an expiry 24 hours
// after the call; the remaining columns come from whatever the INSERT returns.
func (r *BackupRepository) InitiateRecovery(ctx context.Context, userID uuid.UUID, method string) (*models.RecoverySession, error) {
	const insertSessionSQL = `
INSERT INTO recovery_sessions (user_id, method, shards_needed, expires_at)
VALUES ($1, $2, $3, $4)
RETURNING id, user_id, method, shards_received, shards_needed, status, expires_at, completed_at, created_at
`
	created := new(models.RecoverySession)
	deadline := time.Now().Add(24 * time.Hour)

	row := r.pool.QueryRow(ctx, insertSessionSQL, userID, method, 3, deadline)
	if err := row.Scan(
		&created.ID, &created.UserID, &created.Method, &created.ShardsReceived,
		&created.ShardsNeeded, &created.Status, &created.ExpiresAt, &created.CompletedAt, &created.CreatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to initiate recovery: %w", err)
	}
	return created, nil
}
// SubmitShard records a guardian's shard for a recovery session and reports
// how many shards the session has collected so far.
//
// The session must exist, be in status 'pending' and not be past its
// expires_at; otherwise the call fails with "invalid or expired recovery
// session". (The expiry condition was previously missing from the lookup, so
// expired sessions kept accepting shards.)
//
// After the shard is stored, all submissions of the session are counted. Once
// the count reaches shards_needed the session row is updated with that count
// and moved to status 'in_progress'; below the threshold the session row is
// left untouched. The returned submission carries the count and the threshold
// in ShardsReceived and ShardsNeeded.
//
// NOTE(review): nothing in this method prevents the same guardian from
// submitting more than once for a session — confirm a database constraint
// covers that.
func (r *BackupRepository) SubmitShard(ctx context.Context, sessionID, guardianUserID uuid.UUID, shardEncrypted []byte) (*models.RecoveryShardSubmission, error) {
	// Get session details; only pending, unexpired sessions qualify.
	var shardsNeeded int
	err := r.pool.QueryRow(ctx, `
SELECT shards_needed FROM recovery_sessions
WHERE id = $1 AND status = 'pending' AND expires_at > NOW()
`, sessionID).Scan(&shardsNeeded)
	if err != nil {
		return nil, fmt.Errorf("invalid or expired recovery session: %w", err)
	}

	// Insert shard submission
	var submission models.RecoveryShardSubmission
	err = r.pool.QueryRow(ctx, `
INSERT INTO recovery_shard_submissions (session_id, guardian_user_id, shard_encrypted)
VALUES ($1, $2, $3)
RETURNING id, session_id, guardian_user_id, shard_encrypted, submitted_at, created_at
`, sessionID, guardianUserID, shardEncrypted).Scan(
		&submission.ID, &submission.SessionID, &submission.GuardianUserID,
		&submission.ShardEncrypted, &submission.SubmittedAt, &submission.CreatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to submit shard: %w", err)
	}

	// Count total shards received
	var shardsReceived int
	err = r.pool.QueryRow(ctx, `
SELECT COUNT(*) FROM recovery_shard_submissions WHERE session_id = $1
`, sessionID).Scan(&shardsReceived)
	if err != nil {
		return nil, fmt.Errorf("failed to count shards: %w", err)
	}

	// Update session if threshold reached
	if shardsReceived >= shardsNeeded {
		_, err = r.pool.Exec(ctx, `
UPDATE recovery_sessions
SET shards_received = $1, status = 'in_progress'
WHERE id = $2
`, shardsReceived, sessionID)
		if err != nil {
			return nil, fmt.Errorf("failed to update session: %w", err)
		}
	}

	submission.ShardsReceived = shardsReceived
	submission.ShardsNeeded = shardsNeeded
	return &submission, nil
}
// CompleteRecovery loads the shards submitted for a recovery session, marks
// the session as completed and returns the recovered key material.
//
// At least 3 shards must have been submitted, otherwise the call fails with
// "insufficient shards to complete recovery" and the session is not modified.
//
// Simplification carried over from the original: no secret reconstruction is
// performed; the first shard returned by the query is handed back as the key.
//
// NOTE(review): userID is accepted but not read — the session is selected and
// completed by sessionID alone, and neither its status nor its expiry is
// checked here. Confirm that the caller verifies ownership and session state.
func (r *BackupRepository) CompleteRecovery(ctx context.Context, sessionID, userID uuid.UUID) ([]byte, error) {
	// Get all shards for this session
	rows, err := r.pool.Query(ctx, `
SELECT shard_encrypted FROM recovery_shard_submissions
WHERE session_id = $1
`, sessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get shards: %w", err)
	}
	defer rows.Close()

	var shards [][]byte
	for rows.Next() {
		var shard []byte
		err := rows.Scan(&shard)
		if err != nil {
			return nil, fmt.Errorf("failed to scan shard: %w", err)
		}
		shards = append(shards, shard)
	}
	// Next returning false can mean "no more rows" or "iteration failed";
	// without this check a failure could be mistaken for too few shards.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to get shards: %w", err)
	}

	// Check if we have enough shards (simplified - in real implementation use Shamir's)
	if len(shards) < 3 {
		return nil, fmt.Errorf("insufficient shards to complete recovery")
	}

	// For this simplified implementation, we'll just return the first shard
	// In a real implementation, you'd use Shamir's Secret Sharing to reconstruct
	masterKey := shards[0]

	// Mark session as completed
	_, err = r.pool.Exec(ctx, `
UPDATE recovery_sessions
SET status = 'completed', completed_at = NOW()
WHERE id = $1
`, sessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to complete session: %w", err)
	}
	return masterKey, nil
}
// Backup Preferences Methods
// GetBackupPreferences returns the backup preferences row of userID.
//
// When the user has no row yet, one is inserted via createDefaultPreferences
// and that new row is returned. Any other query or scan failure is wrapped and
// returned.
func (r *BackupRepository) GetBackupPreferences(ctx context.Context, userID uuid.UUID) (*models.BackupPreferences, error) {
	var prefs models.BackupPreferences
	err := r.pool.QueryRow(ctx, `
SELECT id, user_id, cloud_backup_enabled, auto_backup_enabled, backup_frequency_hours,
last_backup_at, backup_password_hash, backup_salt, created_at, updated_at
FROM backup_preferences
WHERE user_id = $1
`, userID).Scan(
		&prefs.ID, &prefs.UserID, &prefs.CloudBackupEnabled, &prefs.AutoBackupEnabled,
		&prefs.BackupFrequencyHours, &prefs.LastBackupAt, &prefs.BackupPasswordHash,
		&prefs.BackupSalt, &prefs.CreatedAt, &prefs.UpdatedAt,
	)
	// errors.Is keeps the create-on-miss branch working even if the sentinel
	// is ever returned wrapped.
	if errors.Is(err, pgx.ErrNoRows) {
		// Create default preferences
		return r.createDefaultPreferences(ctx, userID)
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get backup preferences: %w", err)
	}
	return &prefs, nil
}
// createDefaultPreferences inserts a backup_preferences row for userID that
// specifies only the user ID, and returns the row as stored. Every other
// column therefore takes whatever value the database assigns on insert.
func (r *BackupRepository) createDefaultPreferences(ctx context.Context, userID uuid.UUID) (*models.BackupPreferences, error) {
	const insertDefaultsSQL = `
INSERT INTO backup_preferences (user_id)
VALUES ($1)
RETURNING id, user_id, cloud_backup_enabled, auto_backup_enabled, backup_frequency_hours,
last_backup_at, backup_password_hash, backup_salt, created_at, updated_at
`
	created := new(models.BackupPreferences)
	row := r.pool.QueryRow(ctx, insertDefaultsSQL, userID)
	if err := row.Scan(
		&created.ID, &created.UserID, &created.CloudBackupEnabled, &created.AutoBackupEnabled,
		&created.BackupFrequencyHours, &created.LastBackupAt, &created.BackupPasswordHash,
		&created.BackupSalt, &created.CreatedAt, &created.UpdatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to create default preferences: %w", err)
	}
	return created, nil
}
// UpdateBackupPreferences writes the cloud-backup flag, auto-backup flag and
// backup frequency from req to the preferences row of userID and refreshes
// updated_at.
//
// The backup password and salt are replaced only when req.BackupPassword is
// non-empty. A request without a password now leaves the stored values as they
// are; previously it overwrote both columns with NULL, so changing any other
// preference wiped the password. As a consequence this method cannot clear a
// stored password.
//
// The number of affected rows is not inspected, so a user without a
// preferences row is not reported as an error.
//
// SECURITY NOTE(review): the value written to backup_password_hash is the
// password itself, not a hash, and the generated salt is stored but not used
// to derive anything. This must be replaced with a real password hash
// (the original author names Argon2id) before production use.
func (r *BackupRepository) UpdateBackupPreferences(ctx context.Context, userID uuid.UUID, req models.UpdateBackupPreferencesRequest) error {
	// Both stay nil (sent as NULL) when no new password is supplied.
	var passwordHash *string
	var salt *[]byte
	if req.BackupPassword != "" {
		// Generate salt
		saltBytes := make([]byte, 32)
		if _, err := rand.Read(saltBytes); err != nil {
			return fmt.Errorf("failed to generate salt: %w", err)
		}
		salt = &saltBytes
		// For now, just store the password as hash (simplified)
		hash := req.BackupPassword // In real implementation: Argon2id hash
		passwordHash = &hash
	}

	// COALESCE keeps the existing column value whenever the parameter is NULL.
	_, err := r.pool.Exec(ctx, `
UPDATE backup_preferences
SET cloud_backup_enabled = $1, auto_backup_enabled = $2, backup_frequency_hours = $3,
backup_password_hash = COALESCE($4, backup_password_hash),
backup_salt = COALESCE($5, backup_salt), updated_at = NOW()
WHERE user_id = $6
`, req.CloudBackupEnabled, req.AutoBackupEnabled, req.BackupFrequencyHours,
		passwordHash, salt, userID)
	if err != nil {
		return fmt.Errorf("failed to update backup preferences: %w", err)
	}
	return nil
}
// UpdateLastBackupTime sets last_backup_at of userID's preferences row to the
// database's current time.
//
// The number of affected rows is not inspected, so a user without a
// preferences row is not reported as an error.
func (r *BackupRepository) UpdateLastBackupTime(ctx context.Context, userID uuid.UUID) error {
	const touchSQL = `
UPDATE backup_preferences SET last_backup_at = NOW() WHERE user_id = $1
`
	if _, err := r.pool.Exec(ctx, touchSQL, userID); err != nil {
		return fmt.Errorf("failed to update last backup time: %w", err)
	}
	return nil
}
// Device Management Methods
// RegisterDevice inserts a device for userID or, when a row with the same
// (user_id, device_fingerprint) already exists, refreshes that row.
//
// On conflict the device name is overwritten, last_seen_at is set to the
// database's current time and is_active is set to true; device_type is left
// as stored. The resulting row is returned.
func (r *BackupRepository) RegisterDevice(ctx context.Context, userID uuid.UUID, deviceName, deviceFingerprint, deviceType string) (*models.UserDevice, error) {
	const upsertDeviceSQL = `
INSERT INTO user_devices (user_id, device_fingerprint, device_name, device_type, last_seen_at)
VALUES ($1, $2, $3, $4, NOW())
ON CONFLICT (user_id, device_fingerprint)
DO UPDATE SET device_name = EXCLUDED.device_name, last_seen_at = NOW(), is_active = true
RETURNING id, user_id, device_fingerprint, device_name, device_type, last_seen_at, is_active, created_at
`
	registered := new(models.UserDevice)
	row := r.pool.QueryRow(ctx, upsertDeviceSQL, userID, deviceFingerprint, deviceName, deviceType)
	if err := row.Scan(
		&registered.ID, &registered.UserID, &registered.DeviceFingerprint, &registered.DeviceName,
		&registered.DeviceType, &registered.LastSeenAt, &registered.IsActive, &registered.CreatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to register device: %w", err)
	}
	return registered, nil
}
// GetUserDevices returns every device registered for userID, most recently
// seen first. The result is a nil slice when the user has no devices.
func (r *BackupRepository) GetUserDevices(ctx context.Context, userID uuid.UUID) ([]models.UserDevice, error) {
	rows, err := r.pool.Query(ctx, `
SELECT id, user_id, device_fingerprint, device_name, device_type, last_seen_at, is_active, created_at
FROM user_devices
WHERE user_id = $1
ORDER BY last_seen_at DESC
`, userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get user devices: %w", err)
	}
	defer rows.Close()

	var devices []models.UserDevice
	for rows.Next() {
		var device models.UserDevice
		err := rows.Scan(
			&device.ID, &device.UserID, &device.DeviceFingerprint, &device.DeviceName,
			&device.DeviceType, &device.LastSeenAt, &device.IsActive, &device.CreatedAt,
		)
		if err != nil {
			return nil, fmt.Errorf("failed to scan device: %w", err)
		}
		devices = append(devices, device)
	}
	// Next returning false can mean "no more rows" or "iteration failed";
	// without this check a failure would be reported as a shorter list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to get user devices: %w", err)
	}
	return devices, nil
}
// generateRandomCode returns a string of `length` decimal digits, each drawn
// independently from crypto/rand.
//
// If the random source reports an error for a position, that position is
// filled with '0' and generation continues with the next one.
func generateRandomCode(length int) string {
	digits := make([]byte, length)
	ten := big.NewInt(10)
	for pos := 0; pos < length; pos++ {
		n, err := rand.Int(rand.Reader, ten)
		if err == nil {
			// n is in [0, 10), so adding it to '0' yields an ASCII digit.
			digits[pos] = '0' + byte(n.Int64())
		} else {
			digits[pos] = '0'
		}
	}
	return string(digits)
}