Refactor and improve code consistency across multiple files

- Enhanced SQL queries in `session.go` and `share.go` for clarity and consistency.
- Updated comments for better understanding and maintenance.
- Ensured consistent error handling and return statements across various methods.
This commit is contained in:
M1ngdaXie
2026-02-04 22:01:47 -08:00
parent 0f4cff89a2
commit c84cbafb2c
18 changed files with 629 additions and 631 deletions

2
.gitignore vendored
View File

@@ -33,3 +33,5 @@ build/
# Docker volumes and data # Docker volumes and data
postgres_data/ postgres_data/
.claude/

View File

@@ -60,4 +60,4 @@ func ValidateJWT(tokenString, secret string) (*UserClaims, error) {
} }
return nil, errors.New("invalid token claims") return nil, errors.New("invalid token claims")
} }

View File

@@ -18,142 +18,142 @@ const ContextUserIDKey = "user_id"
// AuthMiddleware provides auth middleware // AuthMiddleware provides auth middleware
type AuthMiddleware struct { type AuthMiddleware struct {
store store.Store store store.Store
jwtSecret string jwtSecret string
} }
// NewAuthMiddleware creates a new auth middleware // NewAuthMiddleware creates a new auth middleware
func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware { func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware {
return &AuthMiddleware{ return &AuthMiddleware{
store: store, store: store,
jwtSecret: jwtSecret, jwtSecret: jwtSecret,
} }
} }
// RequireAuth middleware requires valid authentication // RequireAuth middleware requires valid authentication
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc { func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
fmt.Println("🔒 RequireAuth: Starting authentication check") fmt.Println("🔒 RequireAuth: Starting authentication check")
user, claims, err := m.getUserFromToken(c) user, claims, err := m.getUserFromToken(c)
fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err) fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
if claims != nil { if claims != nil {
fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email) fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
} }
if err != nil || user == nil { if err != nil || user == nil {
fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user) fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user)
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort() c.Abort()
return return
} }
// Note: Name and Email might be empty for old JWT tokens // Note: Name and Email might be empty for old JWT tokens
if claims.Name == "" || claims.Email == "" { if claims.Name == "" || claims.Email == "" {
fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n") fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
} }
fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user) fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
c.Set(ContextUserIDKey, user) c.Set(ContextUserIDKey, user)
c.Set("user_email", claims.Email) c.Set("user_email", claims.Email)
c.Set("user_name", claims.Name) c.Set("user_name", claims.Name)
if claims.AvatarURL != nil { if claims.AvatarURL != nil {
c.Set("avatar_url", *claims.AvatarURL) c.Set("avatar_url", *claims.AvatarURL)
} }
c.Next() c.Next()
} }
} }
// OptionalAuth middleware sets user if authenticated, but doesn't require it // OptionalAuth middleware sets user if authenticated, but doesn't require it
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc { func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
user, claims, _ := m.getUserFromToken(c) user, claims, _ := m.getUserFromToken(c)
if user != nil { if user != nil {
c.Set(string(UserContextKey), user) c.Set(string(UserContextKey), user)
c.Set(ContextUserIDKey, user) c.Set(ContextUserIDKey, user)
if claims != nil { if claims != nil {
c.Set("user_email", claims.Email) c.Set("user_email", claims.Email)
c.Set("user_name", claims.Name) c.Set("user_name", claims.Name)
if claims.AvatarURL != nil { if claims.AvatarURL != nil {
c.Set("avatar_url", *claims.AvatarURL) c.Set("avatar_url", *claims.AvatarURL)
} }
} }
} }
c.Next() c.Next()
} }
} }
// getUserFromToken parses the JWT and returns the UserID and the full Claims (for name/email) // getUserFromToken parses the JWT and returns the UserID and the full Claims (for name/email)
// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error) // 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error)
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) { func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader) fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)
if authHeader == "" { if authHeader == "" {
fmt.Println("⚠️ getUserFromToken: No Authorization header") fmt.Println("⚠️ getUserFromToken: No Authorization header")
return nil, nil, nil return nil, nil, nil
} }
parts := strings.Split(authHeader, " ") parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" { if len(parts) != 2 || parts[0] != "Bearer" {
fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0]) fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
return nil, nil, nil return nil, nil, nil
} }
tokenString := parts[1] tokenString := parts[1]
fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))]) fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
// 必须要验证签名算法是 HMAC (HS256) // 必须要验证签名算法是 HMAC (HS256)
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
} }
return []byte(m.jwtSecret), nil return []byte(m.jwtSecret), nil
}) })
if err != nil { if err != nil {
fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err) fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
return nil, nil, err return nil, nil, err
} }
// 2. 验证 Token 有效性并提取 Claims // 2. 验证 Token 有效性并提取 Claims
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid { if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
// 3. 把 String 类型的 Subject 转回 UUID // 3. 把 String 类型的 Subject 转回 UUID
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String() // 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
userID, err := uuid.Parse(claims.Subject) userID, err := uuid.Parse(claims.Subject)
if err != nil { if err != nil {
fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err) fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
return nil, nil, fmt.Errorf("invalid user ID in token") return nil, nil, fmt.Errorf("invalid user ID in token")
} }
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email) // 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
// 这一步完全没有查数据库,速度极快 // 这一步完全没有查数据库,速度极快
fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email) fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
return &userID, claims, nil return &userID, claims, nil
} }
fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid") fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
return nil, nil, fmt.Errorf("invalid token claims") return nil, nil, fmt.Errorf("invalid token claims")
} }
// GetUserFromContext extracts user ID from context // GetUserFromContext extracts user ID from context
func GetUserFromContext(c *gin.Context) *uuid.UUID { func GetUserFromContext(c *gin.Context) *uuid.UUID {
// 修正点:使用和存入时完全一样的 Key // 修正点:使用和存入时完全一样的 Key
val, exists := c.Get(ContextUserIDKey) val, exists := c.Get(ContextUserIDKey)
fmt.Println("within getFromContext the id is ... ") fmt.Println("within getFromContext the id is ... ")
fmt.Println(val); fmt.Println(val)
if !exists { if !exists {
return nil return nil
} }
// 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型) // 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型)
uid, ok := val.(*uuid.UUID) uid, ok := val.(*uuid.UUID)
if !ok { if !ok {
return nil return nil
} }
return uid return uid
} }
// ValidateToken validates a JWT token and returns user ID, name, and avatar URL from JWT claims // ValidateToken validates a JWT token and returns user ID, name, and avatar URL from JWT claims

View File

@@ -8,25 +8,25 @@ import (
// GetGoogleOAuthConfig returns Google OAuth2 config // GetGoogleOAuthConfig returns Google OAuth2 config
func GetGoogleOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config { func GetGoogleOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
return &oauth2.Config{ return &oauth2.Config{
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
RedirectURL: redirectURL, RedirectURL: redirectURL,
Scopes: []string{ Scopes: []string{
"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile", "https://www.googleapis.com/auth/userinfo.profile",
}, },
Endpoint: google.Endpoint, Endpoint: google.Endpoint,
} }
} }
// GetGitHubOAuthConfig returns GitHub OAuth2 config // GetGitHubOAuthConfig returns GitHub OAuth2 config
func GetGitHubOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config { func GetGitHubOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
return &oauth2.Config{ return &oauth2.Config{
ClientID: clientID, ClientID: clientID,
ClientSecret: clientSecret, ClientSecret: clientSecret,
RedirectURL: redirectURL, RedirectURL: redirectURL,
Scopes: []string{"user:email"}, Scopes: []string{"user:email"},
Endpoint: github.Endpoint, Endpoint: github.Endpoint,
} }
} }

View File

@@ -64,10 +64,10 @@ func (h *AuthHandler) GoogleLogin(c *gin.Context) {
// GoogleCallback handles Google OAuth callback // GoogleCallback handles Google OAuth callback
func (h *AuthHandler) GoogleCallback(c *gin.Context) { func (h *AuthHandler) GoogleCallback(c *gin.Context) {
oauthState, err := c.Cookie("oauthstate") oauthState, err := c.Cookie("oauthstate")
if err != nil || c.Query("state") != oauthState { if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"}) c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return return
} }
log.Println("Google callback state:", c.Query("state")) log.Println("Google callback state:", c.Query("state"))
// Exchange code for token // Exchange code for token
token, err := h.googleConfig.Exchange(c.Request.Context(), c.Query("code")) token, err := h.googleConfig.Exchange(c.Request.Context(), c.Query("code"))
@@ -94,11 +94,11 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
Name string `json:"name"` Name string `json:"name"`
Picture string `json:"picture"` Picture string `json:"picture"`
} }
if err := json.Unmarshal(data, &userInfo); err != nil { if err := json.Unmarshal(data, &userInfo); err != nil {
log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data)) log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"})
return return
} }
log.Println("Google user info:", userInfo) log.Println("Google user info:", userInfo)
// Upsert user in database // Upsert user in database
@@ -118,10 +118,10 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
// Create session and JWT // Create session and JWT
jwt, err := h.createSessionAndJWT(c, user) jwt, err := h.createSessionAndJWT(c, user)
if err != nil { if err != nil {
fmt.Printf("❌ DATABASE ERROR: %v\n", err) fmt.Printf("❌ DATABASE ERROR: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{ c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("CreateSession Error: %v", err), "error": fmt.Sprintf("CreateSession Error: %v", err),
}) })
return return
} }
@@ -140,10 +140,10 @@ func (h *AuthHandler) GithubLogin(c *gin.Context) {
// GithubCallback handles GitHub OAuth callback // GithubCallback handles GitHub OAuth callback
func (h *AuthHandler) GithubCallback(c *gin.Context) { func (h *AuthHandler) GithubCallback(c *gin.Context) {
oauthState, err := c.Cookie("oauthstate") oauthState, err := c.Cookie("oauthstate")
if err != nil || c.Query("state") != oauthState { if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"}) c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return return
} }
log.Println("Github callback state:", c.Query("state")) log.Println("Github callback state:", c.Query("state"))
code := c.Query("code") code := c.Query("code")
if code == "" { if code == "" {
@@ -160,7 +160,7 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
// Get user info from GitHub // Get user info from GitHub
client := h.githubConfig.Client(c.Request.Context(), token) client := h.githubConfig.Client(c.Request.Context(), token)
// Get user profile // Get user profile
resp, err := client.Get("https://api.github.com/user") resp, err := client.Get("https://api.github.com/user")
if err != nil { if err != nil {
@@ -178,10 +178,10 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
AvatarURL string `json:"avatar_url"` AvatarURL string `json:"avatar_url"`
} }
if err := json.Unmarshal(data, &userInfo); err != nil { if err := json.Unmarshal(data, &userInfo); err != nil {
log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data)) log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"})
return return
} }
// If email is not public, fetch it separately // If email is not public, fetch it separately
if userInfo.Email == "" { if userInfo.Email == "" {
@@ -315,10 +315,10 @@ func (h *AuthHandler) generateStateOauthCookie(w http.ResponseWriter) string {
Name: "oauthstate", Name: "oauthstate",
Value: state, Value: state,
Expires: time.Now().Add(10 * time.Minute), Expires: time.Now().Add(10 * time.Minute),
HttpOnly: true, // Prevents JavaScript access (XSS protection) HttpOnly: true, // Prevents JavaScript access (XSS protection)
Secure: h.cfg.SecureCookie, // true in production, false for localhost Secure: h.cfg.SecureCookie, // true in production, false for localhost
SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects
Path: "/", // Ensures cookie is sent to all backend paths Path: "/", // Ensures cookie is sent to all backend paths
} }
http.SetCookie(w, &cookie) http.SetCookie(w, &cookie)

View File

@@ -11,16 +11,15 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
) )
type DocumentHandler struct { type DocumentHandler struct {
store *store.PostgresStore store *store.PostgresStore
} }
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler { func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
return &DocumentHandler{store: s} return &DocumentHandler{store: s}
} }
// CreateDocument creates a new document (requires auth)
// CreateDocument creates a new document (requires auth)
func (h *DocumentHandler) CreateDocument(c *gin.Context) { func (h *DocumentHandler) CreateDocument(c *gin.Context) {
userID := auth.GetUserFromContext(c) userID := auth.GetUserFromContext(c)
if userID == nil { if userID == nil {
@@ -44,7 +43,7 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
c.JSON(http.StatusCreated, doc) c.JSON(http.StatusCreated, doc)
} }
func (h *DocumentHandler) ListDocuments(c *gin.Context) { func (h *DocumentHandler) ListDocuments(c *gin.Context) {
userID := auth.GetUserFromContext(c) userID := auth.GetUserFromContext(c)
fmt.Println("Getting userId, which is : ") fmt.Println("Getting userId, which is : ")
fmt.Println(userID) fmt.Println(userID)
@@ -66,8 +65,7 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
}) })
} }
func (h *DocumentHandler) GetDocument(c *gin.Context) {
func (h *DocumentHandler) GetDocument(c *gin.Context) {
id, err := uuid.Parse(c.Param("id")) id, err := uuid.Parse(c.Param("id"))
if err != nil { if err != nil {
respondBadRequest(c, "Invalid document ID format") respondBadRequest(c, "Invalid document ID format")
@@ -104,8 +102,9 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
c.JSON(http.StatusOK, doc) c.JSON(http.StatusOK, doc)
} }
// GetDocumentState returns the Yjs state for a document
// GetDocumentState retrieves document state (requires view permission) // GetDocumentState returns the Yjs state for a document
// GetDocumentState retrieves document state (requires view permission)
func (h *DocumentHandler) GetDocumentState(c *gin.Context) { func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
id, err := uuid.Parse(c.Param("id")) id, err := uuid.Parse(c.Param("id"))
if err != nil { if err != nil {
@@ -143,7 +142,7 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
c.Data(http.StatusOK, "application/octet-stream", state) c.Data(http.StatusOK, "application/octet-stream", state)
} }
// UpdateDocumentState updates document state (requires edit permission) // UpdateDocumentState updates document state (requires edit permission)
func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) { func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
id, err := uuid.Parse(c.Param("id")) id, err := uuid.Parse(c.Param("id"))
if err != nil { if err != nil {
@@ -195,7 +194,7 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"}) c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
} }
// DeleteDocument deletes a document (owner only) // DeleteDocument deletes a document (owner only)
func (h *DocumentHandler) DeleteDocument(c *gin.Context) { func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
id, err := uuid.Parse(c.Param("id")) id, err := uuid.Parse(c.Param("id"))
if err != nil { if err != nil {
@@ -237,107 +236,109 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
// GetDocumentPermission returns the user's permission level for a document // GetDocumentPermission returns the user's permission level for a document
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) { func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id")) documentID, err := uuid.Parse(c.Param("id"))
if err != nil { if err != nil {
respondBadRequest(c, "Invalid document ID format") respondBadRequest(c, "Invalid document ID format")
return return
} }
// 1. 先检查文档是否存在 (Good practice) // 1. 先检查文档是否存在 (Good practice)
_, err = h.store.GetDocument(documentID) _, err = h.store.GetDocument(documentID)
if err != nil { if err != nil {
respondNotFound(c, "document") respondNotFound(c, "document")
return return
} }
userID := auth.GetUserFromContext(c) userID := auth.GetUserFromContext(c)
shareToken := c.Query("share") shareToken := c.Query("share")
// 定义两个临时变量,用来存两边的结果 // 定义两个临时变量,用来存两边的结果
var userPerm string // 存 document_shares 的结果 var userPerm string // 存 document_shares 的结果
var tokenPerm string // 存 share_token 的结果 var tokenPerm string // 存 share_token 的结果
// ==================================================== // ====================================================
// 步骤 A: 检查个人权限 (Base Permission) // 步骤 A: 检查个人权限 (Base Permission)
// ==================================================== // ====================================================
if userID != nil { if userID != nil {
perm, err := h.store.GetUserPermission(c.Request.Context(), documentID, *userID) perm, err := h.store.GetUserPermission(c.Request.Context(), documentID, *userID)
if err != nil { if err != nil {
respondInternalError(c, "Failed to get user permission", err) respondInternalError(c, "Failed to get user permission", err)
return return
} }
userPerm = perm userPerm = perm
// ⚠️ 注意:如果 perm 是空,这里不报错!继续往下走! // ⚠️ 注意:如果 perm 是空,这里不报错!继续往下走!
} }
// ==================================================== // ====================================================
// 步骤 B: 检查 Token 权限 (Upgrade Permission) // 步骤 B: 检查 Token 权限 (Upgrade Permission)
// ==================================================== // ====================================================
if shareToken != "" { if shareToken != "" {
// 先验证 Token 是否有效 // 先验证 Token 是否有效
valid, err := h.store.ValidateShareToken(c.Request.Context(), documentID, shareToken) valid, err := h.store.ValidateShareToken(c.Request.Context(), documentID, shareToken)
if err != nil { if err != nil {
respondInternalError(c, "Failed to validate token", err) respondInternalError(c, "Failed to validate token", err)
return return
} }
// 只有 Token 有效才去取权限
if valid {
p, err := h.store.GetShareLinkPermission(c.Request.Context(), documentID)
if err != nil {
respondInternalError(c, "Failed to get token permission", err)
return
}
tokenPerm = p
// 处理数据库老数据的 fallback
if tokenPerm == "" { tokenPerm = "view" }
}
}
// ==================================================== // 只有 Token 有效才去取权限
// 步骤 C: ⚡️ 权限合并与计算 (The Brain) if valid {
// ==================================================== p, err := h.store.GetShareLinkPermission(c.Request.Context(), documentID)
if err != nil {
finalPermission := "" respondInternalError(c, "Failed to get token permission", err)
role := "viewer" // 默认角色 return
}
tokenPerm = p
// 处理数据库老数据的 fallback
if tokenPerm == "" {
tokenPerm = "view"
}
}
}
// 1. 如果是 Owner无敌直接返回 // ====================================================
if userPerm == "owner" { // 步骤 C: ⚡️ 权限合并与计算 (The Brain)
finalPermission = "edit" // ====================================================
role = "owner"
// 直接返回,不用看 Token 了
c.JSON(http.StatusOK, models.PermissionResponse{
Permission: finalPermission,
Role: role,
})
return
}
// 2. 比较 User 和 Token取最大值 finalPermission := ""
// 逻辑:只要任意一边给了 "edit",那就是 "edit" role := "viewer" // 默认角色
if userPerm == "edit" || tokenPerm == "edit" {
finalPermission = "edit"
role = "editor"
} else if userPerm == "view" || tokenPerm == "view" {
finalPermission = "view"
role = "viewer"
}
// ==================================================== // 1. 如果是 Owner无敌直接返回
// 步骤 D: 最终判决 if userPerm == "owner" {
// ==================================================== finalPermission = "edit"
if finalPermission == "" { role = "owner"
// 既没个人权限Token 也不对(或者没 Token // 直接返回,不用看 Token
if userID == nil { c.JSON(http.StatusOK, models.PermissionResponse{
respondUnauthorized(c, "Authentication required") // 没登录且没Token Permission: finalPermission,
} else { Role: role,
respondForbidden(c, "You don't have permission") // 登录了但没权限 })
} return
return }
}
c.JSON(http.StatusOK, models.PermissionResponse{ // 2. 比较 User 和 Token取最大值
Permission: finalPermission, // 逻辑:只要任意一边给了 "edit",那就是 "edit"
Role: role, if userPerm == "edit" || tokenPerm == "edit" {
}) finalPermission = "edit"
} role = "editor"
} else if userPerm == "view" || tokenPerm == "view" {
finalPermission = "view"
role = "viewer"
}
// ====================================================
// 步骤 D: 最终判决
// ====================================================
if finalPermission == "" {
// 既没个人权限Token 也不对(或者没 Token
if userID == nil {
respondUnauthorized(c, "Authentication required") // 没登录且没Token
} else {
respondForbidden(c, "You don't have permission") // 登录了但没权限
}
return
}
c.JSON(http.StatusOK, models.PermissionResponse{
Permission: finalPermission,
Role: role,
})
}

View File

@@ -91,5 +91,5 @@ func respondInternalError(c *gin.Context, message string, err error) {
respondWithError(c, http.StatusInternalServerError, "internal_error", message) respondWithError(c, http.StatusInternalServerError, "internal_error", message)
} }
func respondInvalidID(c *gin.Context, message string) { func respondInvalidID(c *gin.Context, message string) {
respondWithError(c, http.StatusBadRequest, "invalid_id", message) respondWithError(c, http.StatusBadRequest, "invalid_id", message)
} }

View File

@@ -153,6 +153,7 @@ func (h *ShareHandler) DeleteShare(c *gin.Context) {
c.Status(204) c.Status(204)
} }
// CreateShareLink generates a public share link // CreateShareLink generates a public share link
func (h *ShareHandler) CreateShareLink(c *gin.Context) { func (h *ShareHandler) CreateShareLink(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id")) documentID, err := uuid.Parse(c.Param("id"))
@@ -290,4 +291,4 @@ func (h *ShareHandler) RevokeShareLink(c *gin.Context) {
// c.JSON(http.StatusOK, gin.H{"message": "Share link revoked successfully"}) // c.JSON(http.StatusOK, gin.H{"message": "Share link revoked successfully"})
c.Status(204) c.Status(204)
} }

View File

@@ -70,7 +70,7 @@ func TestShareHandlerSuite(t *testing.T) {
func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() { func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() {
body := map[string]interface{}{ body := map[string]interface{}{
"user_email": "bob@test.com", "user_email": "bob@test.com",
"permission": "view", "permission": "view",
} }
@@ -87,7 +87,7 @@ func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() {
func (s *ShareHandlerSuite) TestCreateShare_EditPermission() { func (s *ShareHandlerSuite) TestCreateShare_EditPermission() {
body := map[string]interface{}{ body := map[string]interface{}{
"user_email": "bob@test.com", "user_email": "bob@test.com",
"permission": "edit", "permission": "edit",
} }
@@ -104,7 +104,7 @@ func (s *ShareHandlerSuite) TestCreateShare_EditPermission() {
func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() { func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() {
body := map[string]interface{}{ body := map[string]interface{}{
"user_email": "charlie@test.com", "user_email": "charlie@test.com",
"permission": "view", "permission": "view",
} }
@@ -118,7 +118,7 @@ func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() {
func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() { func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() {
body := map[string]interface{}{ body := map[string]interface{}{
"user_email": "nonexistent@test.com", "user_email": "nonexistent@test.com",
"permission": "view", "permission": "view",
} }
@@ -131,7 +131,7 @@ func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() {
func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() { func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() {
body := map[string]interface{}{ body := map[string]interface{}{
"user_email": "bob@test.com", "user_email": "bob@test.com",
"permission": "admin", // Invalid permission "permission": "admin", // Invalid permission
} }
@@ -145,7 +145,7 @@ func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() {
func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() { func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() {
// Create initial share with view permission // Create initial share with view permission
body := map[string]interface{}{ body := map[string]interface{}{
"user_email": "bob@test.com", "user_email": "bob@test.com",
"permission": "view", "permission": "view",
} }
@@ -169,7 +169,7 @@ func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() {
func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() { func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() {
body := map[string]interface{}{ body := map[string]interface{}{
"user_email": "bob@test.com", "user_email": "bob@test.com",
"permission": "view", "permission": "view",
} }
@@ -182,7 +182,7 @@ func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() {
func (s *ShareHandlerSuite) TestCreateShare_InvalidDocumentID() { func (s *ShareHandlerSuite) TestCreateShare_InvalidDocumentID() {
body := map[string]interface{}{ body := map[string]interface{}{
"user_email": "bob@test.com", "user_email": "bob@test.com",
"permission": "view", "permission": "view",
} }
@@ -206,8 +206,8 @@ func (s *ShareHandlerSuite) TestListShares_OwnerSeesAll() {
s.assertSuccessResponse(w, http.StatusOK) s.assertSuccessResponse(w, http.StatusOK)
var response models.ShareListResponse var response models.ShareListResponse
s.parseJSONResponse(w, &response) s.parseJSONResponse(w, &response)
shares := response.Shares shares := response.Shares
s.GreaterOrEqual(len(shares), 1, "Should have at least one share") s.GreaterOrEqual(len(shares), 1, "Should have at least one share")
} }
@@ -229,8 +229,8 @@ func (s *ShareHandlerSuite) TestListShares_EmptyList() {
s.assertSuccessResponse(w, http.StatusOK) s.assertSuccessResponse(w, http.StatusOK)
var response models.ShareListResponse var response models.ShareListResponse
s.parseJSONResponse(w, &response) s.parseJSONResponse(w, &response)
shares := response.Shares shares := response.Shares
s.Equal(0, len(shares), "Should have no shares") s.Equal(0, len(shares), "Should have no shares")
} }
@@ -243,8 +243,8 @@ func (s *ShareHandlerSuite) TestListShares_IncludesUserDetails() {
s.assertSuccessResponse(w, http.StatusOK) s.assertSuccessResponse(w, http.StatusOK)
var response models.ShareListResponse var response models.ShareListResponse
s.parseJSONResponse(w, &response) s.parseJSONResponse(w, &response)
shares := response.Shares shares := response.Shares
if len(shares) > 0 { if len(shares) > 0 {
share := shares[0] share := shares[0]
@@ -266,7 +266,7 @@ func (s *ShareHandlerSuite) TestListShares_OrderedByCreatedAt() {
users := []string{"bob@test.com", "charlie@test.com"} users := []string{"bob@test.com", "charlie@test.com"}
for _, email := range users { for _, email := range users {
body := map[string]interface{}{ body := map[string]interface{}{
"user_email": email, "user_email": email,
"permission": "view", "permission": "view",
} }
w, httpReq, err := s.makeAuthRequest("POST", fmt.Sprintf("/api/documents/%s/shares", s.testData.AlicePrivateDoc), body, s.testData.AliceID) w, httpReq, err := s.makeAuthRequest("POST", fmt.Sprintf("/api/documents/%s/shares", s.testData.AlicePrivateDoc), body, s.testData.AliceID)

View File

@@ -48,10 +48,10 @@ func (wsh *WebSocketHandler) HandleWebSocketLoadTest(c *gin.Context) {
clientID := uuid.New().String() clientID := uuid.New().String()
client := hub.NewClient( client := hub.NewClient(
clientID, clientID,
nil, // userID - nil for anonymous nil, // userID - nil for anonymous
userName, // userName userName, // userName
nil, // userAvatar nil, // userAvatar
"edit", // permission - full access for load testing "edit", // permission - full access for load testing
conn, conn,
wsh.hub, wsh.hub,
roomID, roomID,

View File

@@ -14,28 +14,26 @@ const (
) )
type Document struct { type Document struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Type DocumentType `json:"type"` Type DocumentType `json:"type"`
YjsState []byte `json:"-"` YjsState []byte `json:"-"`
OwnerID *uuid.UUID `json:"owner_id"` // NEW OwnerID *uuid.UUID `json:"owner_id"` // NEW
Is_Public bool `json:"is_public"` Is_Public bool `json:"is_public"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }
type CreateDocumentRequest struct { type CreateDocumentRequest struct {
Name string `json:"name" binding:"required"` Name string `json:"name" binding:"required"`
Type DocumentType `json:"type" binding:"required"` Type DocumentType `json:"type" binding:"required"`
} }
type UpdateStateRequest struct { type UpdateStateRequest struct {
State []byte `json:"state" binding:"required"` State []byte `json:"state" binding:"required"`
} }
type DocumentListResponse struct {
Documents []Document `json:"documents"`
Total int `json:"total"`
}
type DocumentListResponse struct {
Documents []Document `json:"documents"`
Total int `json:"total"`
}

View File

@@ -7,30 +7,30 @@ import (
) )
type DocumentShare struct { type DocumentShare struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
DocumentID uuid.UUID `json:"document_id"` DocumentID uuid.UUID `json:"document_id"`
UserID uuid.UUID `json:"user_id"` UserID uuid.UUID `json:"user_id"`
Permission string `json:"permission"` // "view" or "edit" Permission string `json:"permission"` // "view" or "edit"
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
CreatedBy *uuid.UUID `json:"created_by"` CreatedBy *uuid.UUID `json:"created_by"`
} }
type CreateShareRequest struct { type CreateShareRequest struct {
UserEmail string `json:"user_email" binding:"required"` UserEmail string `json:"user_email" binding:"required"`
Permission string `json:"permission" binding:"required,oneof=view edit"` Permission string `json:"permission" binding:"required,oneof=view edit"`
} }
type ShareListResponse struct { type ShareListResponse struct {
Shares []DocumentShareWithUser `json:"shares"` Shares []DocumentShareWithUser `json:"shares"`
} }
type DocumentShareWithUser struct { type DocumentShareWithUser struct {
DocumentShare DocumentShare
User User `json:"user"` User User `json:"user"`
} }
// PermissionResponse represents the user's permission level for a document // PermissionResponse represents the user's permission level for a document
type PermissionResponse struct { type PermissionResponse struct {
Permission string `json:"permission"` // "view" or "edit" Permission string `json:"permission"` // "view" or "edit"
Role string `json:"role"` // "owner", "editor", or "viewer" Role string `json:"role"` // "owner", "editor", or "viewer"
} }

View File

@@ -7,42 +7,42 @@ import (
) )
type User struct { type User struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
Email string `json:"email"` Email string `json:"email"`
Name string `json:"name"` Name string `json:"name"`
AvatarURL *string `json:"avatar_url"` AvatarURL *string `json:"avatar_url"`
Provider string `json:"provider"` Provider string `json:"provider"`
ProviderUserID string `json:"-"` // Don't expose ProviderUserID string `json:"-"` // Don't expose
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
LastLoginAt *time.Time `json:"last_login_at"` LastLoginAt *time.Time `json:"last_login_at"`
} }
type Session struct { type Session struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"` UserID uuid.UUID `json:"user_id"`
TokenHash string `json:"-"` // SHA-256 hash of JWT TokenHash string `json:"-"` // SHA-256 hash of JWT
ExpiresAt time.Time `json:"expires_at"` ExpiresAt time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UserAgent *string `json:"user_agent"` UserAgent *string `json:"user_agent"`
IPAddress *string `json:"ip_address"` IPAddress *string `json:"ip_address"`
} }
type OAuthToken struct { type OAuthToken struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"` UserID uuid.UUID `json:"user_id"`
Provider string `json:"provider"` Provider string `json:"provider"`
AccessToken string `json:"-"` // Don't expose AccessToken string `json:"-"` // Don't expose
RefreshToken *string `json:"-"` RefreshToken *string `json:"-"`
TokenType string `json:"token_type"` TokenType string `json:"token_type"`
ExpiresAt time.Time `json:"expires_at"` ExpiresAt time.Time `json:"expires_at"`
Scope *string `json:"scope"` Scope *string `json:"scope"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"` UpdatedAt time.Time `json:"updated_at"`
} }
// Response for /auth/me endpoint // Response for /auth/me endpoint
type UserResponse struct { type UserResponse struct {
User *User `json:"user"` User *User `json:"user"`
Token string `json:"token,omitempty"` Token string `json:"token,omitempty"`
} }

View File

@@ -10,8 +10,8 @@ import (
type DocumentVersion struct { type DocumentVersion struct {
ID uuid.UUID `json:"id"` ID uuid.UUID `json:"id"`
DocumentID uuid.UUID `json:"document_id"` DocumentID uuid.UUID `json:"document_id"`
YjsSnapshot []byte `json:"-"` // Omit from JSON (binary) YjsSnapshot []byte `json:"-"` // Omit from JSON (binary)
TextPreview *string `json:"text_preview"` // Full plain text TextPreview *string `json:"text_preview"` // Full plain text
VersionNumber int `json:"version_number"` VersionNumber int `json:"version_number"`
CreatedBy *uuid.UUID `json:"created_by"` CreatedBy *uuid.UUID `json:"created_by"`
VersionLabel *string `json:"version_label"` VersionLabel *string `json:"version_label"`

View File

@@ -13,33 +13,33 @@ import (
// Store interface defines all database operations // Store interface defines all database operations
type Store interface { type Store interface {
// Document operations // Document operations
CreateDocument(name string, docType models.DocumentType) (*models.Document, error) CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
GetDocument(id uuid.UUID) (*models.Document, error) GetDocument(id uuid.UUID) (*models.Document, error)
ListDocuments() ([]models.Document, error) ListDocuments() ([]models.Document, error)
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
UpdateDocumentState(id uuid.UUID, state []byte) error UpdateDocumentState(id uuid.UUID, state []byte) error
DeleteDocument(id uuid.UUID) error DeleteDocument(id uuid.UUID) error
// User operations // User operations
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error) GetUserByEmail(ctx context.Context, email string) (*models.User, error)
// Session operations // Session operations
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
GetSessionByToken(ctx context.Context, token string) (*models.Session, error) GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
DeleteSession(ctx context.Context, token string) error DeleteSession(ctx context.Context, token string) error
CleanupExpiredSessions(ctx context.Context) error CleanupExpiredSessions(ctx context.Context) error
// Share operations // Share operations
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error)
ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error) ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error)
RevokeShareToken(ctx context.Context, documentID uuid.UUID) error RevokeShareToken(ctx context.Context, documentID uuid.UUID) error
@@ -53,13 +53,11 @@ type Store interface {
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error) GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error) GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
Close() error
Close() error
} }
type PostgresStore struct { type PostgresStore struct {
db *sql.DB db *sql.DB
} }
func NewPostgresStore(databaseUrl string) (*PostgresStore, error) { func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
@@ -68,17 +66,17 @@ func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
return nil, error return nil, error
} }
if err := db.Ping(); err != nil { if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err) return nil, fmt.Errorf("failed to ping database: %w", err)
} }
db.SetMaxOpenConns(25) db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5) db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute) db.SetConnMaxLifetime(5 * time.Minute)
return &PostgresStore{db: db}, nil return &PostgresStore{db: db}, nil
} }
func (s *PostgresStore) Close() error { func (s *PostgresStore) Close() error {
return s.db.Close() return s.db.Close()
} }
func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) { func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
doc := &models.Document{ doc := &models.Document{
@@ -95,12 +93,11 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
RETURNING id, name, type, created_at, updated_at RETURNING id, name, type, created_at, updated_at
` `
err := s.db.QueryRow(query, err := s.db.QueryRow(query,
doc.ID, doc.ID,
doc.Name, doc.Name,
doc.Type, doc.Type,
doc.CreatedAt, doc.CreatedAt,
doc.UpdatedAt, doc.UpdatedAt,
).Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt) ).Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil { if err != nil {
@@ -109,164 +106,163 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
return doc, nil return doc, nil
} }
// GetDocument retrieves a document by ID // GetDocument retrieves a document by ID
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) { func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
doc := &models.Document{} doc := &models.Document{}
query := ` query := `
SELECT id, name, type, yjs_state, owner_id, is_public, created_at, updated_at SELECT id, name, type, yjs_state, owner_id, is_public, created_at, updated_at
FROM documents FROM documents
WHERE id = $1 WHERE id = $1
` `
err := s.db.QueryRow(query, id).Scan( err := s.db.QueryRow(query, id).Scan(
&doc.ID, &doc.ID,
&doc.Name, &doc.Name,
&doc.Type, &doc.Type,
&doc.YjsState, &doc.YjsState,
&doc.OwnerID, &doc.OwnerID,
&doc.Is_Public, &doc.Is_Public,
&doc.CreatedAt, &doc.CreatedAt,
&doc.UpdatedAt, &doc.UpdatedAt,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, fmt.Errorf("document not found") return nil, fmt.Errorf("document not found")
} }
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get document: %w", err) return nil, fmt.Errorf("failed to get document: %w", err)
} }
return doc, nil return doc, nil
} }
// ListDocuments retrieves all documents
// ListDocuments retrieves all documents func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
func (s *PostgresStore) ListDocuments() ([]models.Document, error) { query := `
query := `
SELECT id, name, type, created_at, updated_at SELECT id, name, type, created_at, updated_at
FROM documents FROM documents
ORDER BY created_at DESC ORDER BY created_at DESC
` `
rows, err := s.db.Query(query) rows, err := s.db.Query(query)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to list documents: %w", err) return nil, fmt.Errorf("failed to list documents: %w", err)
} }
defer rows.Close() defer rows.Close()
var documents []models.Document var documents []models.Document
for rows.Next() { for rows.Next() {
var doc models.Document var doc models.Document
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt) err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to scan document: %w", err) return nil, fmt.Errorf("failed to scan document: %w", err)
} }
documents = append(documents, doc) documents = append(documents, doc)
} }
return documents, nil return documents, nil
} }
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error { func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
query := ` query := `
UPDATE documents UPDATE documents
SET yjs_state = $1, updated_at = $2 SET yjs_state = $1, updated_at = $2
WHERE id = $3 WHERE id = $3
` `
result, err := s.db.Exec(query, state, time.Now(), id) result, err := s.db.Exec(query, state, time.Now(), id)
if err != nil { if err != nil {
return fmt.Errorf("failed to update document state: %w", err) return fmt.Errorf("failed to update document state: %w", err)
} }
rowsAffected, err := result.RowsAffected() rowsAffected, err := result.RowsAffected()
if err != nil { if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err) return fmt.Errorf("failed to get rows affected: %w", err)
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return fmt.Errorf("document not found") return fmt.Errorf("document not found")
} }
return nil return nil
} }
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error { func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
query := `DELETE FROM documents WHERE id = $1` query := `DELETE FROM documents WHERE id = $1`
result, err := s.db.Exec(query, id) result, err := s.db.Exec(query, id)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete document: %w", err) return fmt.Errorf("failed to delete document: %w", err)
} }
rowsAffected, err := result.RowsAffected() rowsAffected, err := result.RowsAffected()
if err != nil { if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err) return fmt.Errorf("failed to get rows affected: %w", err)
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return fmt.Errorf("document not found") return fmt.Errorf("document not found")
} }
return nil return nil
} }
// CreateDocumentWithOwner creates a new document with owner // CreateDocumentWithOwner creates a new document with owner
func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) { func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) {
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错) // 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
if docType == "" { if docType == "" {
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text" docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
} }
// Validate that docType is one of the allowed values // Validate that docType is one of the allowed values
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban { if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType) return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
} }
doc := &models.Document{ doc := &models.Document{
ID: uuid.New(), ID: uuid.New(),
Name: name, Name: name,
Type: docType, Type: docType,
YjsState: []byte{}, // 这里初始化了空字节 YjsState: []byte{}, // 这里初始化了空字节
OwnerID: ownerID, OwnerID: ownerID,
Is_Public: false, // 显式设置默认值 Is_Public: false, // 显式设置默认值
CreatedAt: time.Now(), CreatedAt: time.Now(),
UpdatedAt: time.Now(), UpdatedAt: time.Now(),
} }
// 2. 补全了 yjs_state 和 is_public // 2. 补全了 yjs_state 和 is_public
query := ` query := `
INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at) INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at
` `
// 3. Scan 的时候也要对应加上
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.OwnerID,
doc.YjsState, // $5
doc.Is_Public, // $6
doc.CreatedAt,
doc.UpdatedAt,
).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.OwnerID,
&doc.YjsState, // Scan 回来
&doc.Is_Public, // Scan 回来
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil { // 3. Scan 的时候也要对应加上
return nil, fmt.Errorf("failed to create document: %w", err) err := s.db.QueryRow(query,
} doc.ID,
return doc, nil doc.Name,
doc.Type,
doc.OwnerID,
doc.YjsState, // $5
doc.Is_Public, // $6
doc.CreatedAt,
doc.UpdatedAt,
).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.OwnerID,
&doc.YjsState, // Scan 回来
&doc.Is_Public, // Scan 回来
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to create document: %w", err)
}
return doc, nil
} }
// ListUserDocuments lists documents owned by or shared with a user // ListUserDocuments lists documents owned by or shared with a user

View File

@@ -12,77 +12,77 @@ import (
// CreateSession creates a new session // CreateSession creates a new session
func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) { func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) {
// Hash the token before storing // Hash the token before storing
hash := sha256.Sum256([]byte(token)) hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:]) tokenHash := hex.EncodeToString(hash[:])
// 【修改点 1】: 在 SQL 里显式加上 id 字段 // 【修改点 1】: 在 SQL 里显式加上 id 字段
// 注意:$1 变成了 id后面的参数序号全部要顺延 (+1) // 注意:$1 变成了 id后面的参数序号全部要顺延 (+1)
query := ` query := `
INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address) INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address)
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
` `
var session models.Session var session models.Session
// 【修改点 2】: 在参数列表的最前面加上 sessionID // 【修改点 2】: 在参数列表的最前面加上 sessionID
// 现在的对应关系: // 现在的对应关系:
// $1 -> sessionID // $1 -> sessionID
// $2 -> userID // $2 -> userID
// $3 -> tokenHash // $3 -> tokenHash
// ... // ...
err := s.db.QueryRowContext(ctx, query, err := s.db.QueryRowContext(ctx, query,
sessionID, // <--- 这里!把它传进去! sessionID, // <--- 这里!把它传进去!
userID, userID,
tokenHash, tokenHash,
expiresAt, expiresAt,
userAgent, userAgent,
ipAddress, ipAddress,
).Scan( ).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt, &session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress, &session.CreatedAt, &session.UserAgent, &session.IPAddress,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &session, nil return &session, nil
} }
// GetSessionByToken retrieves session by JWT token // GetSessionByToken retrieves session by JWT token
func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) { func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) {
hash := sha256.Sum256([]byte(token)) hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:]) tokenHash := hex.EncodeToString(hash[:])
query := ` query := `
SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
FROM sessions FROM sessions
WHERE token_hash = $1 AND expires_at > NOW() WHERE token_hash = $1 AND expires_at > NOW()
` `
var session models.Session var session models.Session
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan( err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt, &session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress, &session.CreatedAt, &session.UserAgent, &session.IPAddress,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &session, nil return &session, nil
} }
// DeleteSession deletes a session (logout) // DeleteSession deletes a session (logout)
func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error { func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error {
hash := sha256.Sum256([]byte(token)) hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:]) tokenHash := hex.EncodeToString(hash[:])
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash) _, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
return err return err
} }
// CleanupExpiredSessions removes expired sessions // CleanupExpiredSessions removes expired sessions
func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error { func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()") _, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
return err return err
} }

View File

@@ -14,34 +14,34 @@ import (
// CreateDocumentShare creates a new share or updates existing one // CreateDocumentShare creates a new share or updates existing one
// Returns the share and a boolean indicating if it was newly created (true) or updated (false) // Returns the share and a boolean indicating if it was newly created (true) or updated (false)
func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) { func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) {
// First check if share already exists // First check if share already exists
var existingID uuid.UUID var existingID uuid.UUID
checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2` checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2`
err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID) err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID)
isNewShare := err != nil // If error (not found), it's a new share isNewShare := err != nil // If error (not found), it's a new share
query := ` query := `
INSERT INTO document_shares (document_id, user_id, permission, created_by) INSERT INTO document_shares (document_id, user_id, permission, created_by)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission
RETURNING id, document_id, user_id, permission, created_at, created_by RETURNING id, document_id, user_id, permission, created_at, created_by
` `
var share models.DocumentShare var share models.DocumentShare
err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan( err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.ID, &share.DocumentID, &share.UserID, &share.Permission,
&share.CreatedAt, &share.CreatedBy, &share.CreatedAt, &share.CreatedBy,
) )
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }
return &share, isNewShare, nil return &share, isNewShare, nil
} }
// ListDocumentShares lists all shares for a document // ListDocumentShares lists all shares for a document
func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) { func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) {
query := ` query := `
SELECT SELECT
ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by, ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by,
u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at
@@ -51,38 +51,38 @@ func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.
ORDER BY ds.created_at DESC ORDER BY ds.created_at DESC
` `
rows, err := s.db.QueryContext(ctx, query, documentID) rows, err := s.db.QueryContext(ctx, query, documentID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
var shares []models.DocumentShareWithUser var shares []models.DocumentShareWithUser
for rows.Next() { for rows.Next() {
var share models.DocumentShareWithUser var share models.DocumentShareWithUser
err := rows.Scan( err := rows.Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy, &share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider, &share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt, &share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
shares = append(shares, share) shares = append(shares, share)
} }
return shares, nil return shares, nil
} }
// DeleteDocumentShare deletes a share // DeleteDocumentShare deletes a share
func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error { func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID) _, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
return err return err
} }
// CanViewDocument checks if user can view document (owner OR has any share) // CanViewDocument checks if user can view document (owner OR has any share)
func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) { func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := ` query := `
SELECT EXISTS( SELECT EXISTS(
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2 SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
UNION UNION
@@ -90,14 +90,14 @@ func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID
) )
` `
var canView bool var canView bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView) err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
return canView, err return canView, err
} }
// CanEditDocument checks if user can edit document (owner OR has edit share) // CanEditDocument checks if user can edit document (owner OR has edit share)
func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) { func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := ` query := `
SELECT EXISTS( SELECT EXISTS(
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2 SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
UNION UNION
@@ -105,21 +105,21 @@ func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID
) )
` `
var canEdit bool var canEdit bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit) err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
return canEdit, err return canEdit, err
} }
// IsDocumentOwner checks if user is the owner // IsDocumentOwner checks if user is the owner
func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) { func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := `SELECT owner_id = $2 FROM documents WHERE id = $1` query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
var isOwner bool var isOwner bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner) err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return false, nil return false, nil
} }
return isOwner, err return isOwner, err
} }
func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) { func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) {
// Generate random 32-byte token // Generate random 32-byte token
@@ -239,4 +239,4 @@ func (s *PostgresStore) GetShareLinkPermission(ctx context.Context, documentID u
} }
return permission, nil return permission, nil
} }

View File

@@ -11,7 +11,7 @@ import (
// UpsertUser creates or updates user from OAuth profile // UpsertUser creates or updates user from OAuth profile
func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) { func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) {
query := ` query := `
INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at) INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at)
VALUES ($1, $2, $3, $4, $5, NOW()) VALUES ($1, $2, $3, $4, $5, NOW())
ON CONFLICT (provider, provider_user_id) ON CONFLICT (provider, provider_user_id)
@@ -24,59 +24,59 @@ func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID
RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
` `
var user models.User var user models.User
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan( err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email) fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
return &user, nil return &user, nil
} }
// GetUserByID retrieves user by ID // GetUserByID retrieves user by ID
func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) { func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
query := ` query := `
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
FROM users WHERE id = $1 FROM users WHERE id = $1
` `
var user models.User var user models.User
err := s.db.QueryRowContext(ctx, query, userID).Scan( err := s.db.QueryRowContext(ctx, query, userID).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &user, nil return &user, nil
} }
// GetUserByEmail retrieves user by email // GetUserByEmail retrieves user by email
func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
query := ` query := `
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
FROM users WHERE email = $1 FROM users WHERE email = $1
` `
var user models.User var user models.User
err := s.db.QueryRowContext(ctx, query, email).Scan( err := s.db.QueryRowContext(ctx, query, email).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
) )
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &user, nil return &user, nil
} }