package auth import ( "context" "crypto/rand" "database/sql" "encoding/hex" "errors" "time" "golang.org/x/crypto/bcrypt" ) var ( ErrUserNotFound = errors.New("user not found") ErrEmailExists = errors.New("email already exists") ErrInvalidPassword = errors.New("invalid password") ErrTokenExpired = errors.New("token expired") ErrTokenInvalid = errors.New("invalid token") ErrWeakPassword = errors.New("password too weak") ) type Repository struct { db *sql.DB } func NewRepository(db *sql.DB) *Repository { return &Repository{db: db} } func (r *Repository) RunMigrations(ctx context.Context) error { migrations := []string{ `CREATE TABLE IF NOT EXISTS auth_users ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), email VARCHAR(255) UNIQUE NOT NULL, password_hash VARCHAR(255), name VARCHAR(255) NOT NULL, avatar TEXT, role VARCHAR(50) DEFAULT 'user', tier VARCHAR(50) DEFAULT 'free', balance DECIMAL(12,2) DEFAULT 0, email_verified BOOLEAN DEFAULT FALSE, provider VARCHAR(50) DEFAULT 'local', provider_id VARCHAR(255), last_login_at TIMESTAMPTZ, created_at TIMESTAMPTZ DEFAULT NOW(), updated_at TIMESTAMPTZ DEFAULT NOW() )`, `CREATE INDEX IF NOT EXISTS idx_auth_users_email ON auth_users(email)`, `CREATE INDEX IF NOT EXISTS idx_auth_users_provider ON auth_users(provider, provider_id)`, `CREATE TABLE IF NOT EXISTS refresh_tokens ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), user_id UUID NOT NULL REFERENCES auth_users(id) ON DELETE CASCADE, token VARCHAR(255) UNIQUE NOT NULL, user_agent TEXT, ip VARCHAR(50), expires_at TIMESTAMPTZ NOT NULL, created_at TIMESTAMPTZ DEFAULT NOW() )`, `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_user ON refresh_tokens(user_id)`, `CREATE INDEX IF NOT EXISTS idx_refresh_tokens_token ON refresh_tokens(token)`, `CREATE TABLE IF NOT EXISTS password_reset_tokens ( id UUID PRIMARY KEY DEFAULT gen_random_uuid(), user_id UUID NOT NULL REFERENCES auth_users(id) ON DELETE CASCADE, token VARCHAR(255) UNIQUE NOT NULL, expires_at TIMESTAMPTZ NOT NULL, used BOOLEAN DEFAULT FALSE, created_at TIMESTAMPTZ DEFAULT NOW() )`, `CREATE INDEX IF NOT EXISTS idx_password_reset_tokens ON password_reset_tokens(token)`, `ALTER TABLE auth_users ADD COLUMN IF NOT EXISTS balance DECIMAL(12,2) DEFAULT 0`, } for _, m := range migrations { if _, err := r.db.ExecContext(ctx, m); err != nil { return err } } return nil } func (r *Repository) CreateUser(ctx context.Context, email, password, name string) (*User, error) { var exists bool err := r.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM auth_users WHERE email = $1)", email).Scan(&exists) if err != nil { return nil, err } if exists { return nil, ErrEmailExists } if len(password) < 8 { return nil, ErrWeakPassword } hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return nil, err } user := &User{ Email: email, PasswordHash: string(hash), Name: name, Role: string(RoleUser), Tier: string(TierFree), Provider: ProviderLocal, } query := ` INSERT INTO auth_users (email, password_hash, name, role, tier, provider) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id, created_at, updated_at ` err = r.db.QueryRowContext(ctx, query, user.Email, user.PasswordHash, user.Name, user.Role, user.Tier, user.Provider, ).Scan(&user.ID, &user.CreatedAt, &user.UpdatedAt) if err != nil { return nil, err } return user, nil } func (r *Repository) GetUserByEmail(ctx context.Context, email string) (*User, error) { query := ` SELECT id, email, password_hash, name, avatar, role, tier, balance, email_verified, provider, provider_id, last_login_at, created_at, updated_at FROM auth_users WHERE email = $1 ` user := &User{} var avatar, providerID sql.NullString var lastLoginTime sql.NullTime var balance sql.NullFloat64 err := r.db.QueryRowContext(ctx, query, email).Scan( &user.ID, &user.Email, &user.PasswordHash, &user.Name, &avatar, &user.Role, &user.Tier, &balance, &user.EmailVerified, &user.Provider, &providerID, &lastLoginTime, &user.CreatedAt, &user.UpdatedAt, ) if err == sql.ErrNoRows { return nil, ErrUserNotFound } if err != nil { return nil, err } if avatar.Valid { user.Avatar = avatar.String } if providerID.Valid { user.ProviderID = providerID.String } if lastLoginTime.Valid { user.LastLoginAt = lastLoginTime.Time } if balance.Valid { user.Balance = balance.Float64 } return user, nil } func (r *Repository) GetUserByID(ctx context.Context, id string) (*User, error) { query := ` SELECT id, email, password_hash, name, avatar, role, tier, balance, email_verified, provider, provider_id, last_login_at, created_at, updated_at FROM auth_users WHERE id = $1 ` user := &User{} var avatar, providerID sql.NullString var lastLoginTime sql.NullTime var balance sql.NullFloat64 err := r.db.QueryRowContext(ctx, query, id).Scan( &user.ID, &user.Email, &user.PasswordHash, &user.Name, &avatar, &user.Role, &user.Tier, &balance, &user.EmailVerified, &user.Provider, &providerID, &lastLoginTime, &user.CreatedAt, &user.UpdatedAt, ) if err == sql.ErrNoRows { return nil, ErrUserNotFound } if err != nil { return nil, err } if avatar.Valid { user.Avatar = avatar.String } if providerID.Valid { user.ProviderID = providerID.String } if lastLoginTime.Valid { user.LastLoginAt = lastLoginTime.Time } if balance.Valid { user.Balance = balance.Float64 } return user, nil } func (r *Repository) ValidatePassword(ctx context.Context, email, password string) (*User, error) { user, err := r.GetUserByEmail(ctx, email) if err != nil { return nil, err } if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(password)); err != nil { return nil, ErrInvalidPassword } r.db.ExecContext(ctx, "UPDATE auth_users SET last_login_at = NOW() WHERE id = $1", user.ID) return user, nil } func (r *Repository) UpdatePassword(ctx context.Context, userID, newPassword string) error { if len(newPassword) < 8 { return ErrWeakPassword } hash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost) if err != nil { return err } _, err = r.db.ExecContext(ctx, "UPDATE auth_users SET password_hash = $2, updated_at = NOW() WHERE id = $1", userID, string(hash), ) return err } func (r *Repository) UpdateProfile(ctx context.Context, userID, name, avatar string) error { _, err := r.db.ExecContext(ctx, "UPDATE auth_users SET name = $2, avatar = $3, updated_at = NOW() WHERE id = $1", userID, name, avatar, ) return err } func (r *Repository) UpdateTier(ctx context.Context, userID string, tier UserTier) error { _, err := r.db.ExecContext(ctx, "UPDATE auth_users SET tier = $2, updated_at = NOW() WHERE id = $1", userID, string(tier), ) return err } func (r *Repository) UpdateRole(ctx context.Context, userID string, role UserRole) error { _, err := r.db.ExecContext(ctx, "UPDATE auth_users SET role = $2, updated_at = NOW() WHERE id = $1", userID, string(role), ) return err } func (r *Repository) UpdateBalance(ctx context.Context, userID string, amount float64) error { _, err := r.db.ExecContext(ctx, "UPDATE auth_users SET balance = balance + $2, updated_at = NOW() WHERE id = $1", userID, amount, ) return err } func (r *Repository) SetBalance(ctx context.Context, userID string, balance float64) error { _, err := r.db.ExecContext(ctx, "UPDATE auth_users SET balance = $2, updated_at = NOW() WHERE id = $1", userID, balance, ) return err } func (r *Repository) CreateRefreshToken(ctx context.Context, userID, userAgent, ip string, duration time.Duration) (*RefreshToken, error) { token := generateSecureToken(32) rt := &RefreshToken{ UserID: userID, Token: token, UserAgent: userAgent, IP: ip, ExpiresAt: time.Now().Add(duration), } query := ` INSERT INTO refresh_tokens (user_id, token, user_agent, ip, expires_at) VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at ` err := r.db.QueryRowContext(ctx, query, rt.UserID, rt.Token, rt.UserAgent, rt.IP, rt.ExpiresAt, ).Scan(&rt.ID, &rt.CreatedAt) if err != nil { return nil, err } return rt, nil } func (r *Repository) ValidateRefreshToken(ctx context.Context, token string) (*RefreshToken, error) { query := ` SELECT id, user_id, token, user_agent, ip, expires_at, created_at FROM refresh_tokens WHERE token = $1 ` rt := &RefreshToken{} err := r.db.QueryRowContext(ctx, query, token).Scan( &rt.ID, &rt.UserID, &rt.Token, &rt.UserAgent, &rt.IP, &rt.ExpiresAt, &rt.CreatedAt, ) if err == sql.ErrNoRows { return nil, ErrTokenInvalid } if err != nil { return nil, err } if time.Now().After(rt.ExpiresAt) { r.db.ExecContext(ctx, "DELETE FROM refresh_tokens WHERE id = $1", rt.ID) return nil, ErrTokenExpired } return rt, nil } func (r *Repository) RevokeRefreshToken(ctx context.Context, token string) error { _, err := r.db.ExecContext(ctx, "DELETE FROM refresh_tokens WHERE token = $1", token) return err } func (r *Repository) RevokeAllRefreshTokens(ctx context.Context, userID string) error { _, err := r.db.ExecContext(ctx, "DELETE FROM refresh_tokens WHERE user_id = $1", userID) return err } func (r *Repository) CreatePasswordResetToken(ctx context.Context, userID string) (*PasswordResetToken, error) { token := generateSecureToken(32) prt := &PasswordResetToken{ UserID: userID, Token: token, ExpiresAt: time.Now().Add(1 * time.Hour), } query := ` INSERT INTO password_reset_tokens (user_id, token, expires_at) VALUES ($1, $2, $3) RETURNING id, created_at ` err := r.db.QueryRowContext(ctx, query, prt.UserID, prt.Token, prt.ExpiresAt, ).Scan(&prt.ID, &prt.CreatedAt) if err != nil { return nil, err } return prt, nil } func (r *Repository) ValidatePasswordResetToken(ctx context.Context, token string) (*PasswordResetToken, error) { query := ` SELECT id, user_id, token, expires_at, used, created_at FROM password_reset_tokens WHERE token = $1 ` prt := &PasswordResetToken{} err := r.db.QueryRowContext(ctx, query, token).Scan( &prt.ID, &prt.UserID, &prt.Token, &prt.ExpiresAt, &prt.Used, &prt.CreatedAt, ) if err == sql.ErrNoRows { return nil, ErrTokenInvalid } if err != nil { return nil, err } if prt.Used { return nil, ErrTokenInvalid } if time.Now().After(prt.ExpiresAt) { return nil, ErrTokenExpired } return prt, nil } func (r *Repository) MarkPasswordResetTokenUsed(ctx context.Context, tokenID string) error { _, err := r.db.ExecContext(ctx, "UPDATE password_reset_tokens SET used = TRUE WHERE id = $1", tokenID) return err } func (r *Repository) CreateOrUpdateOAuthUser(ctx context.Context, provider, providerID, email, name, avatar string) (*User, error) { query := ` SELECT id FROM auth_users WHERE provider = $1 AND provider_id = $2 ` var existingID string err := r.db.QueryRowContext(ctx, query, provider, providerID).Scan(&existingID) if err == sql.ErrNoRows { var emailExists bool r.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM auth_users WHERE email = $1)", email).Scan(&emailExists) if emailExists { _, err := r.db.ExecContext(ctx, "UPDATE auth_users SET provider = $1, provider_id = $2, last_login_at = NOW() WHERE email = $3", provider, providerID, email, ) if err != nil { return nil, err } return r.GetUserByEmail(ctx, email) } user := &User{ Email: email, Name: name, Avatar: avatar, Role: string(RoleUser), Tier: string(TierFree), Provider: provider, ProviderID: providerID, EmailVerified: true, } insertQuery := ` INSERT INTO auth_users (email, name, avatar, role, tier, provider, provider_id, email_verified) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at ` err = r.db.QueryRowContext(ctx, insertQuery, user.Email, user.Name, user.Avatar, user.Role, user.Tier, user.Provider, user.ProviderID, user.EmailVerified, ).Scan(&user.ID, &user.CreatedAt, &user.UpdatedAt) if err != nil { return nil, err } return user, nil } if err != nil { return nil, err } _, err = r.db.ExecContext(ctx, "UPDATE auth_users SET name = $2, avatar = $3, last_login_at = NOW() WHERE id = $1", existingID, name, avatar, ) if err != nil { return nil, err } return r.GetUserByID(ctx, existingID) } func (r *Repository) CleanupExpiredTokens(ctx context.Context) error { _, err := r.db.ExecContext(ctx, "DELETE FROM refresh_tokens WHERE expires_at < NOW()") if err != nil { return err } _, err = r.db.ExecContext(ctx, "DELETE FROM password_reset_tokens WHERE expires_at < NOW()") return err } func generateSecureToken(length int) string { b := make([]byte, length) rand.Read(b) return hex.EncodeToString(b) }