Refactor and improve code consistency across multiple files
- Enhanced SQL queries in `session.go` and `share.go` for clarity and consistency. - Updated comments for better understanding and maintenance. - Ensured consistent error handling and return statements across various methods.
This commit is contained in:
@@ -13,33 +13,33 @@ import (
|
||||
|
||||
// Store interface defines all database operations
|
||||
type Store interface {
|
||||
// Document operations
|
||||
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
|
||||
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
|
||||
GetDocument(id uuid.UUID) (*models.Document, error)
|
||||
ListDocuments() ([]models.Document, error)
|
||||
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
|
||||
UpdateDocumentState(id uuid.UUID, state []byte) error
|
||||
DeleteDocument(id uuid.UUID) error
|
||||
|
||||
// User operations
|
||||
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
|
||||
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
|
||||
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
|
||||
|
||||
// Session operations
|
||||
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
|
||||
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
|
||||
DeleteSession(ctx context.Context, token string) error
|
||||
CleanupExpiredSessions(ctx context.Context) error
|
||||
|
||||
// Share operations
|
||||
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
|
||||
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
|
||||
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
|
||||
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||
// Document operations
|
||||
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
|
||||
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
|
||||
GetDocument(id uuid.UUID) (*models.Document, error)
|
||||
ListDocuments() ([]models.Document, error)
|
||||
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
|
||||
UpdateDocumentState(id uuid.UUID, state []byte) error
|
||||
DeleteDocument(id uuid.UUID) error
|
||||
|
||||
// User operations
|
||||
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
|
||||
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
|
||||
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
|
||||
|
||||
// Session operations
|
||||
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
|
||||
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
|
||||
DeleteSession(ctx context.Context, token string) error
|
||||
CleanupExpiredSessions(ctx context.Context) error
|
||||
|
||||
// Share operations
|
||||
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
|
||||
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
|
||||
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
|
||||
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||
GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error)
|
||||
ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error)
|
||||
RevokeShareToken(ctx context.Context, documentID uuid.UUID) error
|
||||
@@ -53,13 +53,11 @@ type Store interface {
|
||||
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
|
||||
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
|
||||
|
||||
|
||||
Close() error
|
||||
Close() error
|
||||
}
|
||||
|
||||
|
||||
type PostgresStore struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
|
||||
@@ -68,17 +66,17 @@ func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
|
||||
return nil, error
|
||||
}
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
return &PostgresStore{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *PostgresStore) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
|
||||
doc := &models.Document{
|
||||
@@ -95,12 +93,11 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
|
||||
RETURNING id, name, type, created_at, updated_at
|
||||
`
|
||||
err := s.db.QueryRow(query,
|
||||
doc.ID,
|
||||
doc.Name,
|
||||
doc.Type,
|
||||
doc.CreatedAt,
|
||||
doc.ID,
|
||||
doc.Name,
|
||||
doc.Type,
|
||||
doc.CreatedAt,
|
||||
doc.UpdatedAt,
|
||||
|
||||
).Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
|
||||
|
||||
if err != nil {
|
||||
@@ -109,164 +106,163 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// GetDocument retrieves a document by ID
|
||||
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
|
||||
doc := &models.Document{}
|
||||
// GetDocument retrieves a document by ID
|
||||
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
|
||||
doc := &models.Document{}
|
||||
|
||||
query := `
|
||||
query := `
|
||||
SELECT id, name, type, yjs_state, owner_id, is_public, created_at, updated_at
|
||||
FROM documents
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
err := s.db.QueryRow(query, id).Scan(
|
||||
&doc.ID,
|
||||
&doc.Name,
|
||||
&doc.Type,
|
||||
&doc.YjsState,
|
||||
&doc.OwnerID,
|
||||
&doc.Is_Public,
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
)
|
||||
err := s.db.QueryRow(query, id).Scan(
|
||||
&doc.ID,
|
||||
&doc.Name,
|
||||
&doc.Type,
|
||||
&doc.YjsState,
|
||||
&doc.OwnerID,
|
||||
&doc.Is_Public,
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("document not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get document: %w", err)
|
||||
}
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, fmt.Errorf("document not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get document: %w", err)
|
||||
}
|
||||
|
||||
return doc, nil
|
||||
}
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
|
||||
// ListDocuments retrieves all documents
|
||||
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
|
||||
query := `
|
||||
// ListDocuments retrieves all documents
|
||||
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
|
||||
query := `
|
||||
SELECT id, name, type, created_at, updated_at
|
||||
FROM documents
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := s.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list documents: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
rows, err := s.db.Query(query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list documents: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var documents []models.Document
|
||||
for rows.Next() {
|
||||
var doc models.Document
|
||||
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan document: %w", err)
|
||||
}
|
||||
documents = append(documents, doc)
|
||||
}
|
||||
var documents []models.Document
|
||||
for rows.Next() {
|
||||
var doc models.Document
|
||||
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan document: %w", err)
|
||||
}
|
||||
documents = append(documents, doc)
|
||||
}
|
||||
|
||||
return documents, nil
|
||||
}
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
|
||||
query := `
|
||||
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
|
||||
query := `
|
||||
UPDATE documents
|
||||
SET yjs_state = $1, updated_at = $2
|
||||
WHERE id = $3
|
||||
`
|
||||
|
||||
result, err := s.db.Exec(query, state, time.Now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update document state: %w", err)
|
||||
}
|
||||
result, err := s.db.Exec(query, state, time.Now(), id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update document state: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("document not found")
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("document not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
|
||||
query := `DELETE FROM documents WHERE id = $1`
|
||||
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
|
||||
query := `DELETE FROM documents WHERE id = $1`
|
||||
|
||||
result, err := s.db.Exec(query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete document: %w", err)
|
||||
}
|
||||
result, err := s.db.Exec(query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete document: %w", err)
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("document not found")
|
||||
}
|
||||
if rowsAffected == 0 {
|
||||
return fmt.Errorf("document not found")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateDocumentWithOwner creates a new document with owner
|
||||
func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) {
|
||||
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
|
||||
if docType == "" {
|
||||
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
|
||||
}
|
||||
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
|
||||
if docType == "" {
|
||||
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
|
||||
}
|
||||
|
||||
// Validate that docType is one of the allowed values
|
||||
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
|
||||
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
|
||||
}
|
||||
// Validate that docType is one of the allowed values
|
||||
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
|
||||
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
|
||||
}
|
||||
|
||||
doc := &models.Document{
|
||||
ID: uuid.New(),
|
||||
Name: name,
|
||||
Type: docType,
|
||||
YjsState: []byte{}, // 这里初始化了空字节
|
||||
OwnerID: ownerID,
|
||||
Is_Public: false, // 显式设置默认值
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
doc := &models.Document{
|
||||
ID: uuid.New(),
|
||||
Name: name,
|
||||
Type: docType,
|
||||
YjsState: []byte{}, // 这里初始化了空字节
|
||||
OwnerID: ownerID,
|
||||
Is_Public: false, // 显式设置默认值
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 2. 补全了 yjs_state 和 is_public
|
||||
query := `
|
||||
// 2. 补全了 yjs_state 和 is_public
|
||||
query := `
|
||||
INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at
|
||||
`
|
||||
|
||||
// 3. Scan 的时候也要对应加上
|
||||
err := s.db.QueryRow(query,
|
||||
doc.ID,
|
||||
doc.Name,
|
||||
doc.Type,
|
||||
doc.OwnerID,
|
||||
doc.YjsState, // $5
|
||||
doc.Is_Public, // $6
|
||||
doc.CreatedAt,
|
||||
doc.UpdatedAt,
|
||||
).Scan(
|
||||
&doc.ID,
|
||||
&doc.Name,
|
||||
&doc.Type,
|
||||
&doc.OwnerID,
|
||||
&doc.YjsState, // Scan 回来
|
||||
&doc.Is_Public, // Scan 回来
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create document: %w", err)
|
||||
}
|
||||
return doc, nil
|
||||
// 3. Scan 的时候也要对应加上
|
||||
err := s.db.QueryRow(query,
|
||||
doc.ID,
|
||||
doc.Name,
|
||||
doc.Type,
|
||||
doc.OwnerID,
|
||||
doc.YjsState, // $5
|
||||
doc.Is_Public, // $6
|
||||
doc.CreatedAt,
|
||||
doc.UpdatedAt,
|
||||
).Scan(
|
||||
&doc.ID,
|
||||
&doc.Name,
|
||||
&doc.Type,
|
||||
&doc.OwnerID,
|
||||
&doc.YjsState, // Scan 回来
|
||||
&doc.Is_Public, // Scan 回来
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create document: %w", err)
|
||||
}
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// ListUserDocuments lists documents owned by or shared with a user
|
||||
|
||||
@@ -12,77 +12,77 @@ import (
|
||||
|
||||
// CreateSession creates a new session
|
||||
func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) {
|
||||
// Hash the token before storing
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
// Hash the token before storing
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
// 【修改点 1】: 在 SQL 里显式加上 id 字段
|
||||
// 注意:$1 变成了 id,后面的参数序号全部要顺延 (+1)
|
||||
query := `
|
||||
// 【修改点 1】: 在 SQL 里显式加上 id 字段
|
||||
// 注意:$1 变成了 id,后面的参数序号全部要顺延 (+1)
|
||||
query := `
|
||||
INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
|
||||
`
|
||||
|
||||
var session models.Session
|
||||
// 【修改点 2】: 在参数列表的最前面加上 sessionID
|
||||
// 现在的对应关系:
|
||||
// $1 -> sessionID
|
||||
// $2 -> userID
|
||||
// $3 -> tokenHash
|
||||
// ...
|
||||
err := s.db.QueryRowContext(ctx, query,
|
||||
sessionID, // <--- 这里!把它传进去!
|
||||
userID,
|
||||
tokenHash,
|
||||
expiresAt,
|
||||
userAgent,
|
||||
ipAddress,
|
||||
).Scan(
|
||||
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
|
||||
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var session models.Session
|
||||
// 【修改点 2】: 在参数列表的最前面加上 sessionID
|
||||
// 现在的对应关系:
|
||||
// $1 -> sessionID
|
||||
// $2 -> userID
|
||||
// $3 -> tokenHash
|
||||
// ...
|
||||
err := s.db.QueryRowContext(ctx, query,
|
||||
sessionID, // <--- 这里!把它传进去!
|
||||
userID,
|
||||
tokenHash,
|
||||
expiresAt,
|
||||
userAgent,
|
||||
ipAddress,
|
||||
).Scan(
|
||||
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
|
||||
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// GetSessionByToken retrieves session by JWT token
|
||||
func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
query := `
|
||||
query := `
|
||||
SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
|
||||
FROM sessions
|
||||
WHERE token_hash = $1 AND expires_at > NOW()
|
||||
`
|
||||
|
||||
var session models.Session
|
||||
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
|
||||
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
|
||||
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var session models.Session
|
||||
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
|
||||
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
|
||||
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// DeleteSession deletes a session (logout)
|
||||
func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
|
||||
return err
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
|
||||
return err
|
||||
}
|
||||
|
||||
// CleanupExpiredSessions removes expired sessions
|
||||
func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error {
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
|
||||
return err
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -14,34 +14,34 @@ import (
|
||||
// CreateDocumentShare creates a new share or updates existing one
|
||||
// Returns the share and a boolean indicating if it was newly created (true) or updated (false)
|
||||
func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) {
|
||||
// First check if share already exists
|
||||
var existingID uuid.UUID
|
||||
checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2`
|
||||
err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID)
|
||||
isNewShare := err != nil // If error (not found), it's a new share
|
||||
// First check if share already exists
|
||||
var existingID uuid.UUID
|
||||
checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2`
|
||||
err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID)
|
||||
isNewShare := err != nil // If error (not found), it's a new share
|
||||
|
||||
query := `
|
||||
query := `
|
||||
INSERT INTO document_shares (document_id, user_id, permission, created_by)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission
|
||||
RETURNING id, document_id, user_id, permission, created_at, created_by
|
||||
`
|
||||
|
||||
var share models.DocumentShare
|
||||
err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
|
||||
&share.ID, &share.DocumentID, &share.UserID, &share.Permission,
|
||||
&share.CreatedAt, &share.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
var share models.DocumentShare
|
||||
err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
|
||||
&share.ID, &share.DocumentID, &share.UserID, &share.Permission,
|
||||
&share.CreatedAt, &share.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return &share, isNewShare, nil
|
||||
return &share, isNewShare, nil
|
||||
}
|
||||
|
||||
// ListDocumentShares lists all shares for a document
|
||||
func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) {
|
||||
query := `
|
||||
query := `
|
||||
SELECT
|
||||
ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by,
|
||||
u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at
|
||||
@@ -51,38 +51,38 @@ func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.
|
||||
ORDER BY ds.created_at DESC
|
||||
`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, documentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
rows, err := s.db.QueryContext(ctx, query, documentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var shares []models.DocumentShareWithUser
|
||||
for rows.Next() {
|
||||
var share models.DocumentShareWithUser
|
||||
err := rows.Scan(
|
||||
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
|
||||
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
|
||||
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shares = append(shares, share)
|
||||
}
|
||||
var shares []models.DocumentShareWithUser
|
||||
for rows.Next() {
|
||||
var share models.DocumentShareWithUser
|
||||
err := rows.Scan(
|
||||
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
|
||||
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
|
||||
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shares = append(shares, share)
|
||||
}
|
||||
|
||||
return shares, nil
|
||||
return shares, nil
|
||||
}
|
||||
|
||||
// DeleteDocumentShare deletes a share
|
||||
func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error {
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
|
||||
return err
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// CanViewDocument checks if user can view document (owner OR has any share)
|
||||
func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
||||
query := `
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
|
||||
UNION
|
||||
@@ -90,14 +90,14 @@ func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID
|
||||
)
|
||||
`
|
||||
|
||||
var canView bool
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
|
||||
return canView, err
|
||||
var canView bool
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
|
||||
return canView, err
|
||||
}
|
||||
|
||||
// CanEditDocument checks if user can edit document (owner OR has edit share)
|
||||
func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
||||
query := `
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
|
||||
UNION
|
||||
@@ -105,21 +105,21 @@ func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID
|
||||
)
|
||||
`
|
||||
|
||||
var canEdit bool
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
|
||||
return canEdit, err
|
||||
var canEdit bool
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
|
||||
return canEdit, err
|
||||
}
|
||||
|
||||
// IsDocumentOwner checks if user is the owner
|
||||
func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
||||
query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
|
||||
query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
|
||||
|
||||
var isOwner bool
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return isOwner, err
|
||||
var isOwner bool
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return isOwner, err
|
||||
}
|
||||
func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) {
|
||||
// Generate random 32-byte token
|
||||
@@ -239,4 +239,4 @@ func (s *PostgresStore) GetShareLinkPermission(ctx context.Context, documentID u
|
||||
}
|
||||
|
||||
return permission, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
// UpsertUser creates or updates user from OAuth profile
|
||||
func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) {
|
||||
query := `
|
||||
query := `
|
||||
INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW())
|
||||
ON CONFLICT (provider, provider_user_id)
|
||||
@@ -24,59 +24,59 @@ func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID
|
||||
RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
||||
`
|
||||
|
||||
var user models.User
|
||||
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
|
||||
var user models.User
|
||||
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
|
||||
|
||||
return &user, nil
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves user by ID
|
||||
func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
|
||||
query := `
|
||||
query := `
|
||||
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
||||
FROM users WHERE id = $1
|
||||
`
|
||||
|
||||
var user models.User
|
||||
err := s.db.QueryRowContext(ctx, query, userID).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var user models.User
|
||||
err := s.db.QueryRowContext(ctx, query, userID).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves user by email
|
||||
func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
|
||||
query := `
|
||||
query := `
|
||||
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
||||
FROM users WHERE email = $1
|
||||
`
|
||||
|
||||
var user models.User
|
||||
err := s.db.QueryRowContext(ctx, query, email).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var user models.User
|
||||
err := s.db.QueryRowContext(ctx, query, email).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user