diff --git a/.gitignore b/.gitignore index c88a604..8192af3 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,5 @@ build/ # Docker volumes and data postgres_data/ + +.claude/ \ No newline at end of file diff --git a/backend/internal/auth/jwt.go b/backend/internal/auth/jwt.go index 9f4cb84..ddf8e01 100644 --- a/backend/internal/auth/jwt.go +++ b/backend/internal/auth/jwt.go @@ -60,4 +60,4 @@ func ValidateJWT(tokenString, secret string) (*UserClaims, error) { } return nil, errors.New("invalid token claims") -} \ No newline at end of file +} diff --git a/backend/internal/auth/middleware.go b/backend/internal/auth/middleware.go index e5d076f..9789b04 100644 --- a/backend/internal/auth/middleware.go +++ b/backend/internal/auth/middleware.go @@ -18,142 +18,142 @@ const ContextUserIDKey = "user_id" // AuthMiddleware provides auth middleware type AuthMiddleware struct { - store store.Store - jwtSecret string + store store.Store + jwtSecret string } // NewAuthMiddleware creates a new auth middleware func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware { - return &AuthMiddleware{ - store: store, - jwtSecret: jwtSecret, - } + return &AuthMiddleware{ + store: store, + jwtSecret: jwtSecret, + } } // RequireAuth middleware requires valid authentication func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc { - return func(c *gin.Context) { - fmt.Println("🔒 RequireAuth: Starting authentication check") + return func(c *gin.Context) { + fmt.Println("🔒 RequireAuth: Starting authentication check") - user, claims, err := m.getUserFromToken(c) + user, claims, err := m.getUserFromToken(c) - fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err) - if claims != nil { - fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email) - } + fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err) + if claims != nil { + fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email) + } - if err != nil || user == nil { - fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user) - c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) - c.Abort() - return - } + if err != nil || user == nil { + fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user) + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } - // Note: Name and Email might be empty for old JWT tokens - if claims.Name == "" || claims.Email == "" { - fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n") - } + // Note: Name and Email might be empty for old JWT tokens + if claims.Name == "" || claims.Email == "" { + fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n") + } - fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user) - c.Set(ContextUserIDKey, user) - c.Set("user_email", claims.Email) - c.Set("user_name", claims.Name) - if claims.AvatarURL != nil { - c.Set("avatar_url", *claims.AvatarURL) - } - c.Next() - } + fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user) + c.Set(ContextUserIDKey, user) + c.Set("user_email", claims.Email) + c.Set("user_name", claims.Name) + if claims.AvatarURL != nil { + c.Set("avatar_url", *claims.AvatarURL) + } + c.Next() + } } // OptionalAuth middleware sets user if authenticated, but doesn't require it func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc { - return func(c *gin.Context) { - user, claims, _ := m.getUserFromToken(c) - if user != nil { - c.Set(string(UserContextKey), user) - c.Set(ContextUserIDKey, user) - if claims != nil { - c.Set("user_email", claims.Email) - c.Set("user_name", claims.Name) - if claims.AvatarURL != nil { - c.Set("avatar_url", *claims.AvatarURL) - } - } - } - c.Next() - } + return func(c *gin.Context) { + user, claims, _ := m.getUserFromToken(c) + if user != nil { + c.Set(string(UserContextKey), user) + c.Set(ContextUserIDKey, user) + if claims != nil { + c.Set("user_email", claims.Email) + c.Set("user_name", claims.Name) + if claims.AvatarURL != nil { + c.Set("avatar_url", *claims.AvatarURL) + } + } + } + c.Next() + } } // getUserFromToken parses the JWT and returns the UserID and the full Claims (for name/email) // 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error) func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) { - authHeader := c.GetHeader("Authorization") - fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader) + authHeader := c.GetHeader("Authorization") + fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader) - if authHeader == "" { - fmt.Println("⚠️ getUserFromToken: No Authorization header") - return nil, nil, nil - } + if authHeader == "" { + fmt.Println("⚠️ getUserFromToken: No Authorization header") + return nil, nil, nil + } - parts := strings.Split(authHeader, " ") - if len(parts) != 2 || parts[0] != "Bearer" { - fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0]) - return nil, nil, nil - } + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0]) + return nil, nil, nil + } - tokenString := parts[1] - fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))]) + tokenString := parts[1] + fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))]) - token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { - // 必须要验证签名算法是 HMAC (HS256) - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - return []byte(m.jwtSecret), nil - }) + token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { + // 必须要验证签名算法是 HMAC (HS256) + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(m.jwtSecret), nil + }) - if err != nil { - fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err) - return nil, nil, err - } + if err != nil { + fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err) + return nil, nil, err + } - // 2. 验证 Token 有效性并提取 Claims - if claims, ok := token.Claims.(*UserClaims); ok && token.Valid { - // 3. 把 String 类型的 Subject 转回 UUID - // 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String() - userID, err := uuid.Parse(claims.Subject) - if err != nil { - fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err) - return nil, nil, fmt.Errorf("invalid user ID in token") - } + // 2. 验证 Token 有效性并提取 Claims + if claims, ok := token.Claims.(*UserClaims); ok && token.Valid { + // 3. 把 String 类型的 Subject 转回 UUID + // 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String() + userID, err := uuid.Parse(claims.Subject) + if err != nil { + fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err) + return nil, nil, fmt.Errorf("invalid user ID in token") + } - // 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email) - // 这一步完全没有查数据库,速度极快 - fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email) - return &userID, claims, nil - } + // 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email) + // 这一步完全没有查数据库,速度极快 + fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email) + return &userID, claims, nil + } - fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid") - return nil, nil, fmt.Errorf("invalid token claims") + fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid") + return nil, nil, fmt.Errorf("invalid token claims") } // GetUserFromContext extracts user ID from context func GetUserFromContext(c *gin.Context) *uuid.UUID { - // 修正点:使用和存入时完全一样的 Key - val, exists := c.Get(ContextUserIDKey) - fmt.Println("within getFromContext the id is ... ") - fmt.Println(val); - if !exists { - return nil - } + // 修正点:使用和存入时完全一样的 Key + val, exists := c.Get(ContextUserIDKey) + fmt.Println("within getFromContext the id is ... ") + fmt.Println(val) + if !exists { + return nil + } - // 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型) - uid, ok := val.(*uuid.UUID) - if !ok { - return nil - } + // 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型) + uid, ok := val.(*uuid.UUID) + if !ok { + return nil + } - return uid + return uid } // ValidateToken validates a JWT token and returns user ID, name, and avatar URL from JWT claims diff --git a/backend/internal/auth/oauth.go b/backend/internal/auth/oauth.go index 5dad179..1a5a887 100644 --- a/backend/internal/auth/oauth.go +++ b/backend/internal/auth/oauth.go @@ -8,25 +8,25 @@ import ( // GetGoogleOAuthConfig returns Google OAuth2 config func GetGoogleOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config { - return &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - RedirectURL: redirectURL, - Scopes: []string{ - "https://www.googleapis.com/auth/userinfo.email", - "https://www.googleapis.com/auth/userinfo.profile", - }, - Endpoint: google.Endpoint, - } + return &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURL, + Scopes: []string{ + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + }, + Endpoint: google.Endpoint, + } } // GetGitHubOAuthConfig returns GitHub OAuth2 config func GetGitHubOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config { - return &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - RedirectURL: redirectURL, - Scopes: []string{"user:email"}, - Endpoint: github.Endpoint, - } + return &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURL, + Scopes: []string{"user:email"}, + Endpoint: github.Endpoint, + } } diff --git a/backend/internal/handlers/auth.go b/backend/internal/handlers/auth.go index bd515fc..687b7d7 100644 --- a/backend/internal/handlers/auth.go +++ b/backend/internal/handlers/auth.go @@ -64,10 +64,10 @@ func (h *AuthHandler) GoogleLogin(c *gin.Context) { // GoogleCallback handles Google OAuth callback func (h *AuthHandler) GoogleCallback(c *gin.Context) { oauthState, err := c.Cookie("oauthstate") - if err != nil || c.Query("state") != oauthState { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"}) - return - } + if err != nil || c.Query("state") != oauthState { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"}) + return + } log.Println("Google callback state:", c.Query("state")) // Exchange code for token token, err := h.googleConfig.Exchange(c.Request.Context(), c.Query("code")) @@ -94,11 +94,11 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) { Name string `json:"name"` Picture string `json:"picture"` } - + if err := json.Unmarshal(data, &userInfo); err != nil { - log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"}) - return + log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"}) + return } log.Println("Google user info:", userInfo) // Upsert user in database @@ -118,10 +118,10 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) { // Create session and JWT jwt, err := h.createSessionAndJWT(c, user) if err != nil { - fmt.Printf("❌ DATABASE ERROR: %v\n", err) - c.JSON(http.StatusInternalServerError, gin.H{ - "error": fmt.Sprintf("CreateSession Error: %v", err), - }) + fmt.Printf("❌ DATABASE ERROR: %v\n", err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": fmt.Sprintf("CreateSession Error: %v", err), + }) return } @@ -140,10 +140,10 @@ func (h *AuthHandler) GithubLogin(c *gin.Context) { // GithubCallback handles GitHub OAuth callback func (h *AuthHandler) GithubCallback(c *gin.Context) { oauthState, err := c.Cookie("oauthstate") - if err != nil || c.Query("state") != oauthState { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"}) - return - } + if err != nil || c.Query("state") != oauthState { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"}) + return + } log.Println("Github callback state:", c.Query("state")) code := c.Query("code") if code == "" { @@ -160,7 +160,7 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) { // Get user info from GitHub client := h.githubConfig.Client(c.Request.Context(), token) - + // Get user profile resp, err := client.Get("https://api.github.com/user") if err != nil { @@ -178,10 +178,10 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) { AvatarURL string `json:"avatar_url"` } if err := json.Unmarshal(data, &userInfo); err != nil { - log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data)) - c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"}) - return -} + log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data)) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"}) + return + } // If email is not public, fetch it separately if userInfo.Email == "" { @@ -315,10 +315,10 @@ func (h *AuthHandler) generateStateOauthCookie(w http.ResponseWriter) string { Name: "oauthstate", Value: state, Expires: time.Now().Add(10 * time.Minute), - HttpOnly: true, // Prevents JavaScript access (XSS protection) - Secure: h.cfg.SecureCookie, // true in production, false for localhost - SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects - Path: "/", // Ensures cookie is sent to all backend paths + HttpOnly: true, // Prevents JavaScript access (XSS protection) + Secure: h.cfg.SecureCookie, // true in production, false for localhost + SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects + Path: "/", // Ensures cookie is sent to all backend paths } http.SetCookie(w, &cookie) diff --git a/backend/internal/handlers/document.go b/backend/internal/handlers/document.go index cecb183..3069885 100644 --- a/backend/internal/handlers/document.go +++ b/backend/internal/handlers/document.go @@ -11,16 +11,15 @@ import ( "github.com/google/uuid" ) - type DocumentHandler struct { - store *store.PostgresStore - } +type DocumentHandler struct { + store *store.PostgresStore +} - func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler { - return &DocumentHandler{store: s} - } +func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler { + return &DocumentHandler{store: s} +} - - // CreateDocument creates a new document (requires auth) +// CreateDocument creates a new document (requires auth) func (h *DocumentHandler) CreateDocument(c *gin.Context) { userID := auth.GetUserFromContext(c) if userID == nil { @@ -44,7 +43,7 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) { c.JSON(http.StatusCreated, doc) } - func (h *DocumentHandler) ListDocuments(c *gin.Context) { +func (h *DocumentHandler) ListDocuments(c *gin.Context) { userID := auth.GetUserFromContext(c) fmt.Println("Getting userId, which is : ") fmt.Println(userID) @@ -66,8 +65,7 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) { }) } - - func (h *DocumentHandler) GetDocument(c *gin.Context) { +func (h *DocumentHandler) GetDocument(c *gin.Context) { id, err := uuid.Parse(c.Param("id")) if err != nil { respondBadRequest(c, "Invalid document ID format") @@ -104,8 +102,9 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) { c.JSON(http.StatusOK, doc) } - // GetDocumentState returns the Yjs state for a document - // GetDocumentState retrieves document state (requires view permission) + +// GetDocumentState returns the Yjs state for a document +// GetDocumentState retrieves document state (requires view permission) func (h *DocumentHandler) GetDocumentState(c *gin.Context) { id, err := uuid.Parse(c.Param("id")) if err != nil { @@ -143,7 +142,7 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) { c.Data(http.StatusOK, "application/octet-stream", state) } - // UpdateDocumentState updates document state (requires edit permission) +// UpdateDocumentState updates document state (requires edit permission) func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) { id, err := uuid.Parse(c.Param("id")) if err != nil { @@ -195,7 +194,7 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"}) } - // DeleteDocument deletes a document (owner only) +// DeleteDocument deletes a document (owner only) func (h *DocumentHandler) DeleteDocument(c *gin.Context) { id, err := uuid.Parse(c.Param("id")) if err != nil { @@ -237,107 +236,109 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) { // GetDocumentPermission returns the user's permission level for a document func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) { - documentID, err := uuid.Parse(c.Param("id")) - if err != nil { - respondBadRequest(c, "Invalid document ID format") - return - } + documentID, err := uuid.Parse(c.Param("id")) + if err != nil { + respondBadRequest(c, "Invalid document ID format") + return + } - // 1. 先检查文档是否存在 (Good practice) - _, err = h.store.GetDocument(documentID) - if err != nil { - respondNotFound(c, "document") - return - } + // 1. 先检查文档是否存在 (Good practice) + _, err = h.store.GetDocument(documentID) + if err != nil { + respondNotFound(c, "document") + return + } - userID := auth.GetUserFromContext(c) - shareToken := c.Query("share") + userID := auth.GetUserFromContext(c) + shareToken := c.Query("share") - // 定义两个临时变量,用来存两边的结果 - var userPerm string // 存 document_shares 的结果 - var tokenPerm string // 存 share_token 的结果 + // 定义两个临时变量,用来存两边的结果 + var userPerm string // 存 document_shares 的结果 + var tokenPerm string // 存 share_token 的结果 - // ==================================================== - // 步骤 A: 检查个人权限 (Base Permission) - // ==================================================== - if userID != nil { - perm, err := h.store.GetUserPermission(c.Request.Context(), documentID, *userID) - if err != nil { - respondInternalError(c, "Failed to get user permission", err) - return - } - userPerm = perm - // ⚠️ 注意:如果 perm 是空,这里不报错!继续往下走! - } + // ==================================================== + // 步骤 A: 检查个人权限 (Base Permission) + // ==================================================== + if userID != nil { + perm, err := h.store.GetUserPermission(c.Request.Context(), documentID, *userID) + if err != nil { + respondInternalError(c, "Failed to get user permission", err) + return + } + userPerm = perm + // ⚠️ 注意:如果 perm 是空,这里不报错!继续往下走! + } - // ==================================================== - // 步骤 B: 检查 Token 权限 (Upgrade Permission) - // ==================================================== - if shareToken != "" { - // 先验证 Token 是否有效 - valid, err := h.store.ValidateShareToken(c.Request.Context(), documentID, shareToken) - if err != nil { - respondInternalError(c, "Failed to validate token", err) - return - } - - // 只有 Token 有效才去取权限 - if valid { - p, err := h.store.GetShareLinkPermission(c.Request.Context(), documentID) - if err != nil { - respondInternalError(c, "Failed to get token permission", err) - return - } - tokenPerm = p - // 处理数据库老数据的 fallback - if tokenPerm == "" { tokenPerm = "view" } - } - } + // ==================================================== + // 步骤 B: 检查 Token 权限 (Upgrade Permission) + // ==================================================== + if shareToken != "" { + // 先验证 Token 是否有效 + valid, err := h.store.ValidateShareToken(c.Request.Context(), documentID, shareToken) + if err != nil { + respondInternalError(c, "Failed to validate token", err) + return + } - // ==================================================== - // 步骤 C: ⚡️ 权限合并与计算 (The Brain) - // ==================================================== - - finalPermission := "" - role := "viewer" // 默认角色 + // 只有 Token 有效才去取权限 + if valid { + p, err := h.store.GetShareLinkPermission(c.Request.Context(), documentID) + if err != nil { + respondInternalError(c, "Failed to get token permission", err) + return + } + tokenPerm = p + // 处理数据库老数据的 fallback + if tokenPerm == "" { + tokenPerm = "view" + } + } + } - // 1. 如果是 Owner,无敌,直接返回 - if userPerm == "owner" { - finalPermission = "edit" - role = "owner" - // 直接返回,不用看 Token 了 - c.JSON(http.StatusOK, models.PermissionResponse{ - Permission: finalPermission, - Role: role, - }) - return - } + // ==================================================== + // 步骤 C: ⚡️ 权限合并与计算 (The Brain) + // ==================================================== - // 2. 比较 User 和 Token,取最大值 - // 逻辑:只要任意一边给了 "edit",那就是 "edit" - if userPerm == "edit" || tokenPerm == "edit" { - finalPermission = "edit" - role = "editor" - } else if userPerm == "view" || tokenPerm == "view" { - finalPermission = "view" - role = "viewer" - } + finalPermission := "" + role := "viewer" // 默认角色 - // ==================================================== - // 步骤 D: 最终判决 - // ==================================================== - if finalPermission == "" { - // 既没个人权限,Token 也不对(或者没 Token) - if userID == nil { - respondUnauthorized(c, "Authentication required") // 没登录且没Token - } else { - respondForbidden(c, "You don't have permission") // 登录了但没权限 - } - return - } + // 1. 如果是 Owner,无敌,直接返回 + if userPerm == "owner" { + finalPermission = "edit" + role = "owner" + // 直接返回,不用看 Token 了 + c.JSON(http.StatusOK, models.PermissionResponse{ + Permission: finalPermission, + Role: role, + }) + return + } - c.JSON(http.StatusOK, models.PermissionResponse{ - Permission: finalPermission, - Role: role, - }) -} \ No newline at end of file + // 2. 比较 User 和 Token,取最大值 + // 逻辑:只要任意一边给了 "edit",那就是 "edit" + if userPerm == "edit" || tokenPerm == "edit" { + finalPermission = "edit" + role = "editor" + } else if userPerm == "view" || tokenPerm == "view" { + finalPermission = "view" + role = "viewer" + } + + // ==================================================== + // 步骤 D: 最终判决 + // ==================================================== + if finalPermission == "" { + // 既没个人权限,Token 也不对(或者没 Token) + if userID == nil { + respondUnauthorized(c, "Authentication required") // 没登录且没Token + } else { + respondForbidden(c, "You don't have permission") // 登录了但没权限 + } + return + } + + c.JSON(http.StatusOK, models.PermissionResponse{ + Permission: finalPermission, + Role: role, + }) +} diff --git a/backend/internal/handlers/errors.go b/backend/internal/handlers/errors.go index 1dd67c1..1739705 100644 --- a/backend/internal/handlers/errors.go +++ b/backend/internal/handlers/errors.go @@ -91,5 +91,5 @@ func respondInternalError(c *gin.Context, message string, err error) { respondWithError(c, http.StatusInternalServerError, "internal_error", message) } func respondInvalidID(c *gin.Context, message string) { - respondWithError(c, http.StatusBadRequest, "invalid_id", message) + respondWithError(c, http.StatusBadRequest, "invalid_id", message) } diff --git a/backend/internal/handlers/share.go b/backend/internal/handlers/share.go index 4ab0bbf..0260a6d 100644 --- a/backend/internal/handlers/share.go +++ b/backend/internal/handlers/share.go @@ -153,6 +153,7 @@ func (h *ShareHandler) DeleteShare(c *gin.Context) { c.Status(204) } + // CreateShareLink generates a public share link func (h *ShareHandler) CreateShareLink(c *gin.Context) { documentID, err := uuid.Parse(c.Param("id")) @@ -290,4 +291,4 @@ func (h *ShareHandler) RevokeShareLink(c *gin.Context) { // c.JSON(http.StatusOK, gin.H{"message": "Share link revoked successfully"}) c.Status(204) -} \ No newline at end of file +} diff --git a/backend/internal/handlers/share_test.go b/backend/internal/handlers/share_test.go index 64d88da..6607825 100644 --- a/backend/internal/handlers/share_test.go +++ b/backend/internal/handlers/share_test.go @@ -70,7 +70,7 @@ func TestShareHandlerSuite(t *testing.T) { func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() { body := map[string]interface{}{ - "user_email": "bob@test.com", + "user_email": "bob@test.com", "permission": "view", } @@ -87,7 +87,7 @@ func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() { func (s *ShareHandlerSuite) TestCreateShare_EditPermission() { body := map[string]interface{}{ - "user_email": "bob@test.com", + "user_email": "bob@test.com", "permission": "edit", } @@ -104,7 +104,7 @@ func (s *ShareHandlerSuite) TestCreateShare_EditPermission() { func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() { body := map[string]interface{}{ - "user_email": "charlie@test.com", + "user_email": "charlie@test.com", "permission": "view", } @@ -118,7 +118,7 @@ func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() { func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() { body := map[string]interface{}{ - "user_email": "nonexistent@test.com", + "user_email": "nonexistent@test.com", "permission": "view", } @@ -131,7 +131,7 @@ func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() { func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() { body := map[string]interface{}{ - "user_email": "bob@test.com", + "user_email": "bob@test.com", "permission": "admin", // Invalid permission } @@ -145,7 +145,7 @@ func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() { func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() { // Create initial share with view permission body := map[string]interface{}{ - "user_email": "bob@test.com", + "user_email": "bob@test.com", "permission": "view", } @@ -169,7 +169,7 @@ func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() { func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() { body := map[string]interface{}{ - "user_email": "bob@test.com", + "user_email": "bob@test.com", "permission": "view", } @@ -182,7 +182,7 @@ func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() { func (s *ShareHandlerSuite) TestCreateShare_InvalidDocumentID() { body := map[string]interface{}{ - "user_email": "bob@test.com", + "user_email": "bob@test.com", "permission": "view", } @@ -206,8 +206,8 @@ func (s *ShareHandlerSuite) TestListShares_OwnerSeesAll() { s.assertSuccessResponse(w, http.StatusOK) var response models.ShareListResponse -s.parseJSONResponse(w, &response) -shares := response.Shares + s.parseJSONResponse(w, &response) + shares := response.Shares s.GreaterOrEqual(len(shares), 1, "Should have at least one share") } @@ -229,8 +229,8 @@ func (s *ShareHandlerSuite) TestListShares_EmptyList() { s.assertSuccessResponse(w, http.StatusOK) var response models.ShareListResponse -s.parseJSONResponse(w, &response) -shares := response.Shares + s.parseJSONResponse(w, &response) + shares := response.Shares s.Equal(0, len(shares), "Should have no shares") } @@ -243,8 +243,8 @@ func (s *ShareHandlerSuite) TestListShares_IncludesUserDetails() { s.assertSuccessResponse(w, http.StatusOK) var response models.ShareListResponse -s.parseJSONResponse(w, &response) -shares := response.Shares + s.parseJSONResponse(w, &response) + shares := response.Shares if len(shares) > 0 { share := shares[0] @@ -266,7 +266,7 @@ func (s *ShareHandlerSuite) TestListShares_OrderedByCreatedAt() { users := []string{"bob@test.com", "charlie@test.com"} for _, email := range users { body := map[string]interface{}{ - "user_email": email, + "user_email": email, "permission": "view", } w, httpReq, err := s.makeAuthRequest("POST", fmt.Sprintf("/api/documents/%s/shares", s.testData.AlicePrivateDoc), body, s.testData.AliceID) diff --git a/backend/internal/handlers/websocket_loadtest.go b/backend/internal/handlers/websocket_loadtest.go index bcb285e..bf5c05d 100644 --- a/backend/internal/handlers/websocket_loadtest.go +++ b/backend/internal/handlers/websocket_loadtest.go @@ -48,10 +48,10 @@ func (wsh *WebSocketHandler) HandleWebSocketLoadTest(c *gin.Context) { clientID := uuid.New().String() client := hub.NewClient( clientID, - nil, // userID - nil for anonymous - userName, // userName - nil, // userAvatar - "edit", // permission - full access for load testing + nil, // userID - nil for anonymous + userName, // userName + nil, // userAvatar + "edit", // permission - full access for load testing conn, wsh.hub, roomID, diff --git a/backend/internal/models/document.go b/backend/internal/models/document.go index 7fe6e10..e0c8029 100644 --- a/backend/internal/models/document.go +++ b/backend/internal/models/document.go @@ -14,28 +14,26 @@ const ( ) type Document struct { - ID uuid.UUID `json:"id"` - Name string `json:"name"` - Type DocumentType `json:"type"` - YjsState []byte `json:"-"` - OwnerID *uuid.UUID `json:"owner_id"` // NEW - Is_Public bool `json:"is_public"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uuid.UUID `json:"id"` + Name string `json:"name"` + Type DocumentType `json:"type"` + YjsState []byte `json:"-"` + OwnerID *uuid.UUID `json:"owner_id"` // NEW + Is_Public bool `json:"is_public"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } - type CreateDocumentRequest struct { - Name string `json:"name" binding:"required"` - Type DocumentType `json:"type" binding:"required"` - } + Name string `json:"name" binding:"required"` + Type DocumentType `json:"type" binding:"required"` +} - type UpdateStateRequest struct { - State []byte `json:"state" binding:"required"` - } - - type DocumentListResponse struct { - Documents []Document `json:"documents"` - Total int `json:"total"` - } +type UpdateStateRequest struct { + State []byte `json:"state" binding:"required"` +} +type DocumentListResponse struct { + Documents []Document `json:"documents"` + Total int `json:"total"` +} diff --git a/backend/internal/models/share.go b/backend/internal/models/share.go index 69637d7..effab6b 100644 --- a/backend/internal/models/share.go +++ b/backend/internal/models/share.go @@ -7,30 +7,30 @@ import ( ) type DocumentShare struct { - ID uuid.UUID `json:"id"` - DocumentID uuid.UUID `json:"document_id"` - UserID uuid.UUID `json:"user_id"` - Permission string `json:"permission"` // "view" or "edit" - CreatedAt time.Time `json:"created_at"` - CreatedBy *uuid.UUID `json:"created_by"` + ID uuid.UUID `json:"id"` + DocumentID uuid.UUID `json:"document_id"` + UserID uuid.UUID `json:"user_id"` + Permission string `json:"permission"` // "view" or "edit" + CreatedAt time.Time `json:"created_at"` + CreatedBy *uuid.UUID `json:"created_by"` } type CreateShareRequest struct { - UserEmail string `json:"user_email" binding:"required"` - Permission string `json:"permission" binding:"required,oneof=view edit"` + UserEmail string `json:"user_email" binding:"required"` + Permission string `json:"permission" binding:"required,oneof=view edit"` } type ShareListResponse struct { - Shares []DocumentShareWithUser `json:"shares"` + Shares []DocumentShareWithUser `json:"shares"` } type DocumentShareWithUser struct { - DocumentShare - User User `json:"user"` + DocumentShare + User User `json:"user"` } // PermissionResponse represents the user's permission level for a document type PermissionResponse struct { - Permission string `json:"permission"` // "view" or "edit" - Role string `json:"role"` // "owner", "editor", or "viewer" + Permission string `json:"permission"` // "view" or "edit" + Role string `json:"role"` // "owner", "editor", or "viewer" } diff --git a/backend/internal/models/user.go b/backend/internal/models/user.go index 3a1bbc6..2e2aa6b 100644 --- a/backend/internal/models/user.go +++ b/backend/internal/models/user.go @@ -7,42 +7,42 @@ import ( ) type User struct { - ID uuid.UUID `json:"id"` - Email string `json:"email"` - Name string `json:"name"` - AvatarURL *string `json:"avatar_url"` - Provider string `json:"provider"` - ProviderUserID string `json:"-"` // Don't expose - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` - LastLoginAt *time.Time `json:"last_login_at"` + ID uuid.UUID `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + AvatarURL *string `json:"avatar_url"` + Provider string `json:"provider"` + ProviderUserID string `json:"-"` // Don't expose + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + LastLoginAt *time.Time `json:"last_login_at"` } type Session struct { - ID uuid.UUID `json:"id"` - UserID uuid.UUID `json:"user_id"` - TokenHash string `json:"-"` // SHA-256 hash of JWT - ExpiresAt time.Time `json:"expires_at"` - CreatedAt time.Time `json:"created_at"` - UserAgent *string `json:"user_agent"` - IPAddress *string `json:"ip_address"` + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + TokenHash string `json:"-"` // SHA-256 hash of JWT + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + UserAgent *string `json:"user_agent"` + IPAddress *string `json:"ip_address"` } type OAuthToken struct { - ID uuid.UUID `json:"id"` - UserID uuid.UUID `json:"user_id"` - Provider string `json:"provider"` - AccessToken string `json:"-"` // Don't expose - RefreshToken *string `json:"-"` - TokenType string `json:"token_type"` - ExpiresAt time.Time `json:"expires_at"` - Scope *string `json:"scope"` - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + Provider string `json:"provider"` + AccessToken string `json:"-"` // Don't expose + RefreshToken *string `json:"-"` + TokenType string `json:"token_type"` + ExpiresAt time.Time `json:"expires_at"` + Scope *string `json:"scope"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } // Response for /auth/me endpoint type UserResponse struct { - User *User `json:"user"` - Token string `json:"token,omitempty"` + User *User `json:"user"` + Token string `json:"token,omitempty"` } diff --git a/backend/internal/models/version.go b/backend/internal/models/version.go index 4753ff7..d0ba924 100644 --- a/backend/internal/models/version.go +++ b/backend/internal/models/version.go @@ -10,8 +10,8 @@ import ( type DocumentVersion struct { ID uuid.UUID `json:"id"` DocumentID uuid.UUID `json:"document_id"` - YjsSnapshot []byte `json:"-"` // Omit from JSON (binary) - TextPreview *string `json:"text_preview"` // Full plain text + YjsSnapshot []byte `json:"-"` // Omit from JSON (binary) + TextPreview *string `json:"text_preview"` // Full plain text VersionNumber int `json:"version_number"` CreatedBy *uuid.UUID `json:"created_by"` VersionLabel *string `json:"version_label"` diff --git a/backend/internal/store/postgres.go b/backend/internal/store/postgres.go index 01974e6..fe46d35 100644 --- a/backend/internal/store/postgres.go +++ b/backend/internal/store/postgres.go @@ -13,33 +13,33 @@ import ( // Store interface defines all database operations type Store interface { - // Document operations - CreateDocument(name string, docType models.DocumentType) (*models.Document, error) - CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS - GetDocument(id uuid.UUID) (*models.Document, error) - ListDocuments() ([]models.Document, error) - ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS - UpdateDocumentState(id uuid.UUID, state []byte) error - DeleteDocument(id uuid.UUID) error - - // User operations - UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) - GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) - GetUserByEmail(ctx context.Context, email string) (*models.User, error) - - // Session operations - CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) - GetSessionByToken(ctx context.Context, token string) (*models.Session, error) - DeleteSession(ctx context.Context, token string) error - CleanupExpiredSessions(ctx context.Context) error - - // Share operations - CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) - ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) - DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error - CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) - CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) - IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) + // Document operations + CreateDocument(name string, docType models.DocumentType) (*models.Document, error) + CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS + GetDocument(id uuid.UUID) (*models.Document, error) + ListDocuments() ([]models.Document, error) + ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS + UpdateDocumentState(id uuid.UUID, state []byte) error + DeleteDocument(id uuid.UUID) error + + // User operations + UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) + GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) + GetUserByEmail(ctx context.Context, email string) (*models.User, error) + + // Session operations + CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) + GetSessionByToken(ctx context.Context, token string) (*models.Session, error) + DeleteSession(ctx context.Context, token string) error + CleanupExpiredSessions(ctx context.Context) error + + // Share operations + CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) + ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) + DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error + CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) + CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) + IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error) RevokeShareToken(ctx context.Context, documentID uuid.UUID) error @@ -53,13 +53,11 @@ type Store interface { GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error) GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error) - - Close() error + Close() error } - type PostgresStore struct { - db *sql.DB + db *sql.DB } func NewPostgresStore(databaseUrl string) (*PostgresStore, error) { @@ -68,17 +66,17 @@ func NewPostgresStore(databaseUrl string) (*PostgresStore, error) { return nil, error } if err := db.Ping(); err != nil { - return nil, fmt.Errorf("failed to ping database: %w", err) - } + return nil, fmt.Errorf("failed to ping database: %w", err) + } db.SetMaxOpenConns(25) - db.SetMaxIdleConns(5) - db.SetConnMaxLifetime(5 * time.Minute) + db.SetMaxIdleConns(5) + db.SetConnMaxLifetime(5 * time.Minute) return &PostgresStore{db: db}, nil } func (s *PostgresStore) Close() error { - return s.db.Close() - } + return s.db.Close() +} func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) { doc := &models.Document{ @@ -95,12 +93,11 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) RETURNING id, name, type, created_at, updated_at ` err := s.db.QueryRow(query, - doc.ID, - doc.Name, - doc.Type, - doc.CreatedAt, + doc.ID, + doc.Name, + doc.Type, + doc.CreatedAt, doc.UpdatedAt, - ).Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt) if err != nil { @@ -109,164 +106,163 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) return doc, nil } - // GetDocument retrieves a document by ID - func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) { - doc := &models.Document{} +// GetDocument retrieves a document by ID +func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) { + doc := &models.Document{} - query := ` + query := ` SELECT id, name, type, yjs_state, owner_id, is_public, created_at, updated_at FROM documents WHERE id = $1 ` - err := s.db.QueryRow(query, id).Scan( - &doc.ID, - &doc.Name, - &doc.Type, - &doc.YjsState, - &doc.OwnerID, - &doc.Is_Public, - &doc.CreatedAt, - &doc.UpdatedAt, - ) + err := s.db.QueryRow(query, id).Scan( + &doc.ID, + &doc.Name, + &doc.Type, + &doc.YjsState, + &doc.OwnerID, + &doc.Is_Public, + &doc.CreatedAt, + &doc.UpdatedAt, + ) - if err == sql.ErrNoRows { - return nil, fmt.Errorf("document not found") - } - if err != nil { - return nil, fmt.Errorf("failed to get document: %w", err) - } + if err == sql.ErrNoRows { + return nil, fmt.Errorf("document not found") + } + if err != nil { + return nil, fmt.Errorf("failed to get document: %w", err) + } - return doc, nil - } + return doc, nil +} - - // ListDocuments retrieves all documents - func (s *PostgresStore) ListDocuments() ([]models.Document, error) { - query := ` +// ListDocuments retrieves all documents +func (s *PostgresStore) ListDocuments() ([]models.Document, error) { + query := ` SELECT id, name, type, created_at, updated_at FROM documents ORDER BY created_at DESC ` - rows, err := s.db.Query(query) - if err != nil { - return nil, fmt.Errorf("failed to list documents: %w", err) - } - defer rows.Close() + rows, err := s.db.Query(query) + if err != nil { + return nil, fmt.Errorf("failed to list documents: %w", err) + } + defer rows.Close() - var documents []models.Document - for rows.Next() { - var doc models.Document - err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt) - if err != nil { - return nil, fmt.Errorf("failed to scan document: %w", err) - } - documents = append(documents, doc) - } + var documents []models.Document + for rows.Next() { + var doc models.Document + err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt) + if err != nil { + return nil, fmt.Errorf("failed to scan document: %w", err) + } + documents = append(documents, doc) + } - return documents, nil - } + return documents, nil +} - func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error { - query := ` +func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error { + query := ` UPDATE documents SET yjs_state = $1, updated_at = $2 WHERE id = $3 ` - result, err := s.db.Exec(query, state, time.Now(), id) - if err != nil { - return fmt.Errorf("failed to update document state: %w", err) - } + result, err := s.db.Exec(query, state, time.Now(), id) + if err != nil { + return fmt.Errorf("failed to update document state: %w", err) + } - rowsAffected, err := result.RowsAffected() - if err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) - } + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } - if rowsAffected == 0 { - return fmt.Errorf("document not found") - } + if rowsAffected == 0 { + return fmt.Errorf("document not found") + } - return nil - } + return nil +} - func (s *PostgresStore) DeleteDocument(id uuid.UUID) error { - query := `DELETE FROM documents WHERE id = $1` +func (s *PostgresStore) DeleteDocument(id uuid.UUID) error { + query := `DELETE FROM documents WHERE id = $1` - result, err := s.db.Exec(query, id) - if err != nil { - return fmt.Errorf("failed to delete document: %w", err) - } + result, err := s.db.Exec(query, id) + if err != nil { + return fmt.Errorf("failed to delete document: %w", err) + } - rowsAffected, err := result.RowsAffected() - if err != nil { - return fmt.Errorf("failed to get rows affected: %w", err) - } + rowsAffected, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } - if rowsAffected == 0 { - return fmt.Errorf("document not found") - } + if rowsAffected == 0 { + return fmt.Errorf("document not found") + } - return nil - } + return nil +} // CreateDocumentWithOwner creates a new document with owner func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) { - // 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错) - if docType == "" { - docType = models.DocumentTypeEditor // Default to editor instead of invalid "text" - } + // 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错) + if docType == "" { + docType = models.DocumentTypeEditor // Default to editor instead of invalid "text" + } - // Validate that docType is one of the allowed values - if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban { - return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType) - } + // Validate that docType is one of the allowed values + if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban { + return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType) + } - doc := &models.Document{ - ID: uuid.New(), - Name: name, - Type: docType, - YjsState: []byte{}, // 这里初始化了空字节 - OwnerID: ownerID, - Is_Public: false, // 显式设置默认值 - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } + doc := &models.Document{ + ID: uuid.New(), + Name: name, + Type: docType, + YjsState: []byte{}, // 这里初始化了空字节 + OwnerID: ownerID, + Is_Public: false, // 显式设置默认值 + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } - // 2. 补全了 yjs_state 和 is_public - query := ` + // 2. 补全了 yjs_state 和 is_public + query := ` INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at ` - - // 3. Scan 的时候也要对应加上 - err := s.db.QueryRow(query, - doc.ID, - doc.Name, - doc.Type, - doc.OwnerID, - doc.YjsState, // $5 - doc.Is_Public, // $6 - doc.CreatedAt, - doc.UpdatedAt, - ).Scan( - &doc.ID, - &doc.Name, - &doc.Type, - &doc.OwnerID, - &doc.YjsState, // Scan 回来 - &doc.Is_Public, // Scan 回来 - &doc.CreatedAt, - &doc.UpdatedAt, - ) - if err != nil { - return nil, fmt.Errorf("failed to create document: %w", err) - } - return doc, nil + // 3. Scan 的时候也要对应加上 + err := s.db.QueryRow(query, + doc.ID, + doc.Name, + doc.Type, + doc.OwnerID, + doc.YjsState, // $5 + doc.Is_Public, // $6 + doc.CreatedAt, + doc.UpdatedAt, + ).Scan( + &doc.ID, + &doc.Name, + &doc.Type, + &doc.OwnerID, + &doc.YjsState, // Scan 回来 + &doc.Is_Public, // Scan 回来 + &doc.CreatedAt, + &doc.UpdatedAt, + ) + + if err != nil { + return nil, fmt.Errorf("failed to create document: %w", err) + } + return doc, nil } // ListUserDocuments lists documents owned by or shared with a user diff --git a/backend/internal/store/session.go b/backend/internal/store/session.go index 4e21686..b1852ce 100644 --- a/backend/internal/store/session.go +++ b/backend/internal/store/session.go @@ -12,77 +12,77 @@ import ( // CreateSession creates a new session func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) { - // Hash the token before storing - hash := sha256.Sum256([]byte(token)) - tokenHash := hex.EncodeToString(hash[:]) + // Hash the token before storing + hash := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(hash[:]) - // 【修改点 1】: 在 SQL 里显式加上 id 字段 - // 注意:$1 变成了 id,后面的参数序号全部要顺延 (+1) - query := ` + // 【修改点 1】: 在 SQL 里显式加上 id 字段 + // 注意:$1 变成了 id,后面的参数序号全部要顺延 (+1) + query := ` INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address) VALUES ($1, $2, $3, $4, $5, $6) RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address ` - var session models.Session - // 【修改点 2】: 在参数列表的最前面加上 sessionID - // 现在的对应关系: - // $1 -> sessionID - // $2 -> userID - // $3 -> tokenHash - // ... - err := s.db.QueryRowContext(ctx, query, - sessionID, // <--- 这里!把它传进去! - userID, - tokenHash, - expiresAt, - userAgent, - ipAddress, - ).Scan( - &session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt, - &session.CreatedAt, &session.UserAgent, &session.IPAddress, - ) - if err != nil { - return nil, err - } + var session models.Session + // 【修改点 2】: 在参数列表的最前面加上 sessionID + // 现在的对应关系: + // $1 -> sessionID + // $2 -> userID + // $3 -> tokenHash + // ... + err := s.db.QueryRowContext(ctx, query, + sessionID, // <--- 这里!把它传进去! + userID, + tokenHash, + expiresAt, + userAgent, + ipAddress, + ).Scan( + &session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt, + &session.CreatedAt, &session.UserAgent, &session.IPAddress, + ) + if err != nil { + return nil, err + } - return &session, nil + return &session, nil } // GetSessionByToken retrieves session by JWT token func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) { - hash := sha256.Sum256([]byte(token)) - tokenHash := hex.EncodeToString(hash[:]) + hash := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(hash[:]) - query := ` + query := ` SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address FROM sessions WHERE token_hash = $1 AND expires_at > NOW() ` - var session models.Session - err := s.db.QueryRowContext(ctx, query, tokenHash).Scan( - &session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt, - &session.CreatedAt, &session.UserAgent, &session.IPAddress, - ) - if err != nil { - return nil, err - } + var session models.Session + err := s.db.QueryRowContext(ctx, query, tokenHash).Scan( + &session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt, + &session.CreatedAt, &session.UserAgent, &session.IPAddress, + ) + if err != nil { + return nil, err + } - return &session, nil + return &session, nil } // DeleteSession deletes a session (logout) func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error { - hash := sha256.Sum256([]byte(token)) - tokenHash := hex.EncodeToString(hash[:]) + hash := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(hash[:]) - _, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash) - return err + _, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash) + return err } // CleanupExpiredSessions removes expired sessions func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error { - _, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()") - return err + _, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()") + return err } diff --git a/backend/internal/store/share.go b/backend/internal/store/share.go index c921425..2346f95 100644 --- a/backend/internal/store/share.go +++ b/backend/internal/store/share.go @@ -14,34 +14,34 @@ import ( // CreateDocumentShare creates a new share or updates existing one // Returns the share and a boolean indicating if it was newly created (true) or updated (false) func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) { - // First check if share already exists - var existingID uuid.UUID - checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2` - err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID) - isNewShare := err != nil // If error (not found), it's a new share + // First check if share already exists + var existingID uuid.UUID + checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2` + err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID) + isNewShare := err != nil // If error (not found), it's a new share - query := ` + query := ` INSERT INTO document_shares (document_id, user_id, permission, created_by) VALUES ($1, $2, $3, $4) ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission RETURNING id, document_id, user_id, permission, created_at, created_by ` - var share models.DocumentShare - err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan( - &share.ID, &share.DocumentID, &share.UserID, &share.Permission, - &share.CreatedAt, &share.CreatedBy, - ) - if err != nil { - return nil, false, err - } + var share models.DocumentShare + err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan( + &share.ID, &share.DocumentID, &share.UserID, &share.Permission, + &share.CreatedAt, &share.CreatedBy, + ) + if err != nil { + return nil, false, err + } - return &share, isNewShare, nil + return &share, isNewShare, nil } // ListDocumentShares lists all shares for a document func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) { - query := ` + query := ` SELECT ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by, u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at @@ -51,38 +51,38 @@ func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid. ORDER BY ds.created_at DESC ` - rows, err := s.db.QueryContext(ctx, query, documentID) - if err != nil { - return nil, err - } - defer rows.Close() + rows, err := s.db.QueryContext(ctx, query, documentID) + if err != nil { + return nil, err + } + defer rows.Close() - var shares []models.DocumentShareWithUser - for rows.Next() { - var share models.DocumentShareWithUser - err := rows.Scan( - &share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy, - &share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider, - &share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt, - ) - if err != nil { - return nil, err - } - shares = append(shares, share) - } + var shares []models.DocumentShareWithUser + for rows.Next() { + var share models.DocumentShareWithUser + err := rows.Scan( + &share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy, + &share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider, + &share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt, + ) + if err != nil { + return nil, err + } + shares = append(shares, share) + } - return shares, nil + return shares, nil } // DeleteDocumentShare deletes a share func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error { - _, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID) - return err + _, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID) + return err } // CanViewDocument checks if user can view document (owner OR has any share) func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) { - query := ` + query := ` SELECT EXISTS( SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2 UNION @@ -90,14 +90,14 @@ func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID ) ` - var canView bool - err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView) - return canView, err + var canView bool + err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView) + return canView, err } // CanEditDocument checks if user can edit document (owner OR has edit share) func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) { - query := ` + query := ` SELECT EXISTS( SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2 UNION @@ -105,21 +105,21 @@ func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID ) ` - var canEdit bool - err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit) - return canEdit, err + var canEdit bool + err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit) + return canEdit, err } // IsDocumentOwner checks if user is the owner func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) { - query := `SELECT owner_id = $2 FROM documents WHERE id = $1` + query := `SELECT owner_id = $2 FROM documents WHERE id = $1` - var isOwner bool - err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner) - if err == sql.ErrNoRows { - return false, nil - } - return isOwner, err + var isOwner bool + err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner) + if err == sql.ErrNoRows { + return false, nil + } + return isOwner, err } func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) { // Generate random 32-byte token @@ -239,4 +239,4 @@ func (s *PostgresStore) GetShareLinkPermission(ctx context.Context, documentID u } return permission, nil -} \ No newline at end of file +} diff --git a/backend/internal/store/user.go b/backend/internal/store/user.go index 18e7fd4..79fd964 100644 --- a/backend/internal/store/user.go +++ b/backend/internal/store/user.go @@ -11,7 +11,7 @@ import ( // UpsertUser creates or updates user from OAuth profile func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) { - query := ` + query := ` INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at) VALUES ($1, $2, $3, $4, $5, NOW()) ON CONFLICT (provider, provider_user_id) @@ -24,59 +24,59 @@ func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at ` - var user models.User - err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan( - &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, - &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, - ) - if err != nil { - return nil, err - } - fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email) + var user models.User + err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan( + &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, + &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, + ) + if err != nil { + return nil, err + } + fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email) - return &user, nil + return &user, nil } // GetUserByID retrieves user by ID func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) { - query := ` + query := ` SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at FROM users WHERE id = $1 ` - var user models.User - err := s.db.QueryRowContext(ctx, query, userID).Scan( - &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, - &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, - ) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, err - } + var user models.User + err := s.db.QueryRowContext(ctx, query, userID).Scan( + &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, + &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } - return &user, nil + return &user, nil } // GetUserByEmail retrieves user by email func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { - query := ` + query := ` SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at FROM users WHERE email = $1 ` - var user models.User - err := s.db.QueryRowContext(ctx, query, email).Scan( - &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, - &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, - ) - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, err - } + var user models.User + err := s.db.QueryRowContext(ctx, query, email).Scan( + &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, + &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } - return &user, nil + return &user, nil }