Refactor and improve code consistency across multiple files

- Enhanced SQL queries in `session.go` and `share.go` for clarity and consistency.
- Updated comments for better understanding and maintenance.
- Ensured consistent error handling and return statements across various methods.
This commit is contained in:
M1ngdaXie
2026-02-04 22:01:47 -08:00
parent 0f4cff89a2
commit c84cbafb2c
18 changed files with 629 additions and 631 deletions

View File

@@ -13,33 +13,33 @@ import (
// Store interface defines all database operations
type Store interface {
// Document operations
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
GetDocument(id uuid.UUID) (*models.Document, error)
ListDocuments() ([]models.Document, error)
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
UpdateDocumentState(id uuid.UUID, state []byte) error
DeleteDocument(id uuid.UUID) error
// User operations
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
// Session operations
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
DeleteSession(ctx context.Context, token string) error
CleanupExpiredSessions(ctx context.Context) error
// Share operations
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
// Document operations
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
GetDocument(id uuid.UUID) (*models.Document, error)
ListDocuments() ([]models.Document, error)
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
UpdateDocumentState(id uuid.UUID, state []byte) error
DeleteDocument(id uuid.UUID) error
// User operations
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
// Session operations
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
DeleteSession(ctx context.Context, token string) error
CleanupExpiredSessions(ctx context.Context) error
// Share operations
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error)
ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error)
RevokeShareToken(ctx context.Context, documentID uuid.UUID) error
@@ -53,13 +53,11 @@ type Store interface {
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
Close() error
Close() error
}
type PostgresStore struct {
db *sql.DB
db *sql.DB
}
func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
@@ -68,17 +66,17 @@ func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
return nil, error
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return nil, fmt.Errorf("failed to ping database: %w", err)
}
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
return &PostgresStore{db: db}, nil
}
func (s *PostgresStore) Close() error {
return s.db.Close()
}
return s.db.Close()
}
func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
doc := &models.Document{
@@ -95,12 +93,11 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
RETURNING id, name, type, created_at, updated_at
`
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.CreatedAt,
doc.ID,
doc.Name,
doc.Type,
doc.CreatedAt,
doc.UpdatedAt,
).Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
@@ -109,164 +106,163 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
return doc, nil
}
// GetDocument retrieves a document by ID
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
doc := &models.Document{}
// GetDocument retrieves a document by ID
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
doc := &models.Document{}
query := `
query := `
SELECT id, name, type, yjs_state, owner_id, is_public, created_at, updated_at
FROM documents
WHERE id = $1
`
err := s.db.QueryRow(query, id).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.YjsState,
&doc.OwnerID,
&doc.Is_Public,
&doc.CreatedAt,
&doc.UpdatedAt,
)
err := s.db.QueryRow(query, id).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.YjsState,
&doc.OwnerID,
&doc.Is_Public,
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("document not found")
}
if err != nil {
return nil, fmt.Errorf("failed to get document: %w", err)
}
if err == sql.ErrNoRows {
return nil, fmt.Errorf("document not found")
}
if err != nil {
return nil, fmt.Errorf("failed to get document: %w", err)
}
return doc, nil
}
return doc, nil
}
// ListDocuments retrieves all documents
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
query := `
// ListDocuments retrieves all documents
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
query := `
SELECT id, name, type, created_at, updated_at
FROM documents
ORDER BY created_at DESC
`
rows, err := s.db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to list documents: %w", err)
}
defer rows.Close()
rows, err := s.db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to list documents: %w", err)
}
defer rows.Close()
var documents []models.Document
for rows.Next() {
var doc models.Document
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan document: %w", err)
}
documents = append(documents, doc)
}
var documents []models.Document
for rows.Next() {
var doc models.Document
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan document: %w", err)
}
documents = append(documents, doc)
}
return documents, nil
}
return documents, nil
}
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
query := `
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
query := `
UPDATE documents
SET yjs_state = $1, updated_at = $2
WHERE id = $3
`
result, err := s.db.Exec(query, state, time.Now(), id)
if err != nil {
return fmt.Errorf("failed to update document state: %w", err)
}
result, err := s.db.Exec(query, state, time.Now(), id)
if err != nil {
return fmt.Errorf("failed to update document state: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
return nil
}
return nil
}
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
query := `DELETE FROM documents WHERE id = $1`
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
query := `DELETE FROM documents WHERE id = $1`
result, err := s.db.Exec(query, id)
if err != nil {
return fmt.Errorf("failed to delete document: %w", err)
}
result, err := s.db.Exec(query, id)
if err != nil {
return fmt.Errorf("failed to delete document: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
return nil
}
return nil
}
// CreateDocumentWithOwner creates a new document with owner
func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) {
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
if docType == "" {
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
}
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
if docType == "" {
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
}
// Validate that docType is one of the allowed values
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
}
// Validate that docType is one of the allowed values
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
}
doc := &models.Document{
ID: uuid.New(),
Name: name,
Type: docType,
YjsState: []byte{}, // 这里初始化了空字节
OwnerID: ownerID,
Is_Public: false, // 显式设置默认值
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
doc := &models.Document{
ID: uuid.New(),
Name: name,
Type: docType,
YjsState: []byte{}, // 这里初始化了空字节
OwnerID: ownerID,
Is_Public: false, // 显式设置默认值
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 2. 补全了 yjs_state 和 is_public
query := `
// 2. 补全了 yjs_state 和 is_public
query := `
INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at
`
// 3. Scan 的时候也要对应加上
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.OwnerID,
doc.YjsState, // $5
doc.Is_Public, // $6
doc.CreatedAt,
doc.UpdatedAt,
).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.OwnerID,
&doc.YjsState, // Scan 回来
&doc.Is_Public, // Scan 回来
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to create document: %w", err)
}
return doc, nil
// 3. Scan 的时候也要对应加上
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.OwnerID,
doc.YjsState, // $5
doc.Is_Public, // $6
doc.CreatedAt,
doc.UpdatedAt,
).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.OwnerID,
&doc.YjsState, // Scan 回来
&doc.Is_Public, // Scan 回来
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to create document: %w", err)
}
return doc, nil
}
// ListUserDocuments lists documents owned by or shared with a user