Refactor and improve code consistency across multiple files

- Enhanced SQL queries in `session.go` and `share.go` for clarity and consistency.
- Updated comments for better understanding and maintenance.
- Ensured consistent error handling and return statements across various methods.
This commit is contained in:
M1ngdaXie
2026-02-04 22:01:47 -08:00
parent 0f4cff89a2
commit c84cbafb2c
18 changed files with 629 additions and 631 deletions

View File

@@ -13,33 +13,33 @@ import (
// Store interface defines all database operations
type Store interface {
// Document operations
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
GetDocument(id uuid.UUID) (*models.Document, error)
ListDocuments() ([]models.Document, error)
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
UpdateDocumentState(id uuid.UUID, state []byte) error
DeleteDocument(id uuid.UUID) error
// User operations
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
// Session operations
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
DeleteSession(ctx context.Context, token string) error
CleanupExpiredSessions(ctx context.Context) error
// Share operations
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
// Document operations
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
GetDocument(id uuid.UUID) (*models.Document, error)
ListDocuments() ([]models.Document, error)
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
UpdateDocumentState(id uuid.UUID, state []byte) error
DeleteDocument(id uuid.UUID) error
// User operations
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
// Session operations
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
DeleteSession(ctx context.Context, token string) error
CleanupExpiredSessions(ctx context.Context) error
// Share operations
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error)
ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error)
RevokeShareToken(ctx context.Context, documentID uuid.UUID) error
@@ -53,13 +53,11 @@ type Store interface {
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
Close() error
Close() error
}
type PostgresStore struct {
db *sql.DB
db *sql.DB
}
func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
@@ -68,17 +66,17 @@ func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
return nil, error
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return nil, fmt.Errorf("failed to ping database: %w", err)
}
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
return &PostgresStore{db: db}, nil
}
func (s *PostgresStore) Close() error {
return s.db.Close()
}
return s.db.Close()
}
func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
doc := &models.Document{
@@ -95,12 +93,11 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
RETURNING id, name, type, created_at, updated_at
`
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.CreatedAt,
doc.ID,
doc.Name,
doc.Type,
doc.CreatedAt,
doc.UpdatedAt,
).Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
@@ -109,164 +106,163 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
return doc, nil
}
// GetDocument retrieves a document by ID
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
doc := &models.Document{}
// GetDocument retrieves a document by ID
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
doc := &models.Document{}
query := `
query := `
SELECT id, name, type, yjs_state, owner_id, is_public, created_at, updated_at
FROM documents
WHERE id = $1
`
err := s.db.QueryRow(query, id).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.YjsState,
&doc.OwnerID,
&doc.Is_Public,
&doc.CreatedAt,
&doc.UpdatedAt,
)
err := s.db.QueryRow(query, id).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.YjsState,
&doc.OwnerID,
&doc.Is_Public,
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("document not found")
}
if err != nil {
return nil, fmt.Errorf("failed to get document: %w", err)
}
if err == sql.ErrNoRows {
return nil, fmt.Errorf("document not found")
}
if err != nil {
return nil, fmt.Errorf("failed to get document: %w", err)
}
return doc, nil
}
return doc, nil
}
// ListDocuments retrieves all documents
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
query := `
// ListDocuments retrieves all documents
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
query := `
SELECT id, name, type, created_at, updated_at
FROM documents
ORDER BY created_at DESC
`
rows, err := s.db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to list documents: %w", err)
}
defer rows.Close()
rows, err := s.db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to list documents: %w", err)
}
defer rows.Close()
var documents []models.Document
for rows.Next() {
var doc models.Document
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan document: %w", err)
}
documents = append(documents, doc)
}
var documents []models.Document
for rows.Next() {
var doc models.Document
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan document: %w", err)
}
documents = append(documents, doc)
}
return documents, nil
}
return documents, nil
}
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
query := `
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
query := `
UPDATE documents
SET yjs_state = $1, updated_at = $2
WHERE id = $3
`
result, err := s.db.Exec(query, state, time.Now(), id)
if err != nil {
return fmt.Errorf("failed to update document state: %w", err)
}
result, err := s.db.Exec(query, state, time.Now(), id)
if err != nil {
return fmt.Errorf("failed to update document state: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
return nil
}
return nil
}
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
query := `DELETE FROM documents WHERE id = $1`
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
query := `DELETE FROM documents WHERE id = $1`
result, err := s.db.Exec(query, id)
if err != nil {
return fmt.Errorf("failed to delete document: %w", err)
}
result, err := s.db.Exec(query, id)
if err != nil {
return fmt.Errorf("failed to delete document: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
return nil
}
return nil
}
// CreateDocumentWithOwner creates a new document with owner
func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) {
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
if docType == "" {
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
}
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
if docType == "" {
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
}
// Validate that docType is one of the allowed values
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
}
// Validate that docType is one of the allowed values
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
}
doc := &models.Document{
ID: uuid.New(),
Name: name,
Type: docType,
YjsState: []byte{}, // 这里初始化了空字节
OwnerID: ownerID,
Is_Public: false, // 显式设置默认值
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
doc := &models.Document{
ID: uuid.New(),
Name: name,
Type: docType,
YjsState: []byte{}, // 这里初始化了空字节
OwnerID: ownerID,
Is_Public: false, // 显式设置默认值
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 2. 补全了 yjs_state 和 is_public
query := `
// 2. 补全了 yjs_state 和 is_public
query := `
INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at
`
// 3. Scan 的时候也要对应加上
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.OwnerID,
doc.YjsState, // $5
doc.Is_Public, // $6
doc.CreatedAt,
doc.UpdatedAt,
).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.OwnerID,
&doc.YjsState, // Scan 回来
&doc.Is_Public, // Scan 回来
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to create document: %w", err)
}
return doc, nil
// 3. Scan 的时候也要对应加上
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.OwnerID,
doc.YjsState, // $5
doc.Is_Public, // $6
doc.CreatedAt,
doc.UpdatedAt,
).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.OwnerID,
&doc.YjsState, // Scan 回来
&doc.Is_Public, // Scan 回来
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to create document: %w", err)
}
return doc, nil
}
// ListUserDocuments lists documents owned by or shared with a user

View File

@@ -12,77 +12,77 @@ import (
// CreateSession creates a new session
func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) {
// Hash the token before storing
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
// Hash the token before storing
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
// 【修改点 1】: 在 SQL 里显式加上 id 字段
// 注意:$1 变成了 id后面的参数序号全部要顺延 (+1)
query := `
// 【修改点 1】: 在 SQL 里显式加上 id 字段
// 注意:$1 变成了 id后面的参数序号全部要顺延 (+1)
query := `
INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
`
var session models.Session
// 【修改点 2】: 在参数列表的最前面加上 sessionID
// 现在的对应关系:
// $1 -> sessionID
// $2 -> userID
// $3 -> tokenHash
// ...
err := s.db.QueryRowContext(ctx, query,
sessionID, // <--- 这里!把它传进去!
userID,
tokenHash,
expiresAt,
userAgent,
ipAddress,
).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
var session models.Session
// 【修改点 2】: 在参数列表的最前面加上 sessionID
// 现在的对应关系:
// $1 -> sessionID
// $2 -> userID
// $3 -> tokenHash
// ...
err := s.db.QueryRowContext(ctx, query,
sessionID, // <--- 这里!把它传进去!
userID,
tokenHash,
expiresAt,
userAgent,
ipAddress,
).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
return &session, nil
return &session, nil
}
// GetSessionByToken retrieves session by JWT token
func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
query := `
query := `
SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
FROM sessions
WHERE token_hash = $1 AND expires_at > NOW()
`
var session models.Session
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
var session models.Session
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
return &session, nil
return &session, nil
}
// DeleteSession deletes a session (logout)
func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
return err
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
return err
}
// CleanupExpiredSessions removes expired sessions
func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
return err
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
return err
}

View File

@@ -14,34 +14,34 @@ import (
// CreateDocumentShare creates a new share or updates existing one
// Returns the share and a boolean indicating if it was newly created (true) or updated (false)
func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) {
// First check if share already exists
var existingID uuid.UUID
checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2`
err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID)
isNewShare := err != nil // If error (not found), it's a new share
// First check if share already exists
var existingID uuid.UUID
checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2`
err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID)
isNewShare := err != nil // If error (not found), it's a new share
query := `
query := `
INSERT INTO document_shares (document_id, user_id, permission, created_by)
VALUES ($1, $2, $3, $4)
ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission
RETURNING id, document_id, user_id, permission, created_at, created_by
`
var share models.DocumentShare
err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission,
&share.CreatedAt, &share.CreatedBy,
)
if err != nil {
return nil, false, err
}
var share models.DocumentShare
err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission,
&share.CreatedAt, &share.CreatedBy,
)
if err != nil {
return nil, false, err
}
return &share, isNewShare, nil
return &share, isNewShare, nil
}
// ListDocumentShares lists all shares for a document
func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) {
query := `
query := `
SELECT
ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by,
u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at
@@ -51,38 +51,38 @@ func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.
ORDER BY ds.created_at DESC
`
rows, err := s.db.QueryContext(ctx, query, documentID)
if err != nil {
return nil, err
}
defer rows.Close()
rows, err := s.db.QueryContext(ctx, query, documentID)
if err != nil {
return nil, err
}
defer rows.Close()
var shares []models.DocumentShareWithUser
for rows.Next() {
var share models.DocumentShareWithUser
err := rows.Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
)
if err != nil {
return nil, err
}
shares = append(shares, share)
}
var shares []models.DocumentShareWithUser
for rows.Next() {
var share models.DocumentShareWithUser
err := rows.Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
)
if err != nil {
return nil, err
}
shares = append(shares, share)
}
return shares, nil
return shares, nil
}
// DeleteDocumentShare deletes a share
func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
return err
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
return err
}
// CanViewDocument checks if user can view document (owner OR has any share)
func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := `
query := `
SELECT EXISTS(
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
UNION
@@ -90,14 +90,14 @@ func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID
)
`
var canView bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
return canView, err
var canView bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
return canView, err
}
// CanEditDocument checks if user can edit document (owner OR has edit share)
func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := `
query := `
SELECT EXISTS(
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
UNION
@@ -105,21 +105,21 @@ func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID
)
`
var canEdit bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
return canEdit, err
var canEdit bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
return canEdit, err
}
// IsDocumentOwner checks if user is the owner
func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
var isOwner bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
if err == sql.ErrNoRows {
return false, nil
}
return isOwner, err
var isOwner bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
if err == sql.ErrNoRows {
return false, nil
}
return isOwner, err
}
func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) {
// Generate random 32-byte token
@@ -239,4 +239,4 @@ func (s *PostgresStore) GetShareLinkPermission(ctx context.Context, documentID u
}
return permission, nil
}
}

View File

@@ -11,7 +11,7 @@ import (
// UpsertUser creates or updates user from OAuth profile
func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) {
query := `
query := `
INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at)
VALUES ($1, $2, $3, $4, $5, NOW())
ON CONFLICT (provider, provider_user_id)
@@ -24,59 +24,59 @@ func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID
RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
`
var user models.User
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err != nil {
return nil, err
}
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
var user models.User
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err != nil {
return nil, err
}
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
return &user, nil
return &user, nil
}
// GetUserByID retrieves user by ID
func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
query := `
query := `
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
FROM users WHERE id = $1
`
var user models.User
err := s.db.QueryRowContext(ctx, query, userID).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
var user models.User
err := s.db.QueryRowContext(ctx, query, userID).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &user, nil
return &user, nil
}
// GetUserByEmail retrieves user by email
func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
query := `
query := `
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
FROM users WHERE email = $1
`
var user models.User
err := s.db.QueryRowContext(ctx, query, email).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
var user models.User
err := s.db.QueryRowContext(ctx, query, email).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &user, nil
return &user, nil
}