Refactor and improve code consistency across multiple files

- Enhanced SQL queries in `session.go` and `share.go` for clarity and consistency.
- Updated comments for better understanding and maintenance.
- Ensured consistent error handling and return statements across various methods.
This commit is contained in:
M1ngdaXie
2026-02-04 22:01:47 -08:00
parent 0f4cff89a2
commit c84cbafb2c
18 changed files with 629 additions and 631 deletions

View File

@@ -12,77 +12,77 @@ import (
// CreateSession creates a new session
func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) {
// Hash the token before storing
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
// Hash the token before storing
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
// 【修改点 1】: 在 SQL 里显式加上 id 字段
// 注意:$1 变成了 id后面的参数序号全部要顺延 (+1)
query := `
// 【修改点 1】: 在 SQL 里显式加上 id 字段
// 注意:$1 变成了 id后面的参数序号全部要顺延 (+1)
query := `
INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
`
var session models.Session
// 【修改点 2】: 在参数列表的最前面加上 sessionID
// 现在的对应关系:
// $1 -> sessionID
// $2 -> userID
// $3 -> tokenHash
// ...
err := s.db.QueryRowContext(ctx, query,
sessionID, // <--- 这里!把它传进去!
userID,
tokenHash,
expiresAt,
userAgent,
ipAddress,
).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
var session models.Session
// 【修改点 2】: 在参数列表的最前面加上 sessionID
// 现在的对应关系:
// $1 -> sessionID
// $2 -> userID
// $3 -> tokenHash
// ...
err := s.db.QueryRowContext(ctx, query,
sessionID, // <--- 这里!把它传进去!
userID,
tokenHash,
expiresAt,
userAgent,
ipAddress,
).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
return &session, nil
return &session, nil
}
// GetSessionByToken retrieves session by JWT token
func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
query := `
query := `
SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
FROM sessions
WHERE token_hash = $1 AND expires_at > NOW()
`
var session models.Session
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
var session models.Session
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
return &session, nil
return &session, nil
}
// DeleteSession deletes a session (logout)
func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
return err
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
return err
}
// CleanupExpiredSessions removes expired sessions
func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
return err
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
return err
}