Refactor and improve code consistency across multiple files

- Enhanced SQL queries in `session.go` and `share.go` for clarity and consistency.
- Updated comments for better understanding and maintenance.
- Ensured consistent error handling and return statements across various methods.
This commit is contained in:
M1ngdaXie
2026-02-04 22:01:47 -08:00
parent 0f4cff89a2
commit c84cbafb2c
18 changed files with 629 additions and 631 deletions

View File

@@ -60,4 +60,4 @@ func ValidateJWT(tokenString, secret string) (*UserClaims, error) {
}
return nil, errors.New("invalid token claims")
}
}

View File

@@ -18,142 +18,142 @@ const ContextUserIDKey = "user_id"
// AuthMiddleware provides auth middleware
type AuthMiddleware struct {
store store.Store
jwtSecret string
store store.Store
jwtSecret string
}
// NewAuthMiddleware creates a new auth middleware
func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware {
return &AuthMiddleware{
store: store,
jwtSecret: jwtSecret,
}
return &AuthMiddleware{
store: store,
jwtSecret: jwtSecret,
}
}
// RequireAuth middleware requires valid authentication
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) {
fmt.Println("🔒 RequireAuth: Starting authentication check")
return func(c *gin.Context) {
fmt.Println("🔒 RequireAuth: Starting authentication check")
user, claims, err := m.getUserFromToken(c)
user, claims, err := m.getUserFromToken(c)
fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
if claims != nil {
fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
}
fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
if claims != nil {
fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
}
if err != nil || user == nil {
fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user)
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
if err != nil || user == nil {
fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user)
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
// Note: Name and Email might be empty for old JWT tokens
if claims.Name == "" || claims.Email == "" {
fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
}
// Note: Name and Email might be empty for old JWT tokens
if claims.Name == "" || claims.Email == "" {
fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
}
fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
c.Set(ContextUserIDKey, user)
c.Set("user_email", claims.Email)
c.Set("user_name", claims.Name)
if claims.AvatarURL != nil {
c.Set("avatar_url", *claims.AvatarURL)
}
c.Next()
}
fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
c.Set(ContextUserIDKey, user)
c.Set("user_email", claims.Email)
c.Set("user_name", claims.Name)
if claims.AvatarURL != nil {
c.Set("avatar_url", *claims.AvatarURL)
}
c.Next()
}
}
// OptionalAuth middleware sets user if authenticated, but doesn't require it
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
return func(c *gin.Context) {
user, claims, _ := m.getUserFromToken(c)
if user != nil {
c.Set(string(UserContextKey), user)
c.Set(ContextUserIDKey, user)
if claims != nil {
c.Set("user_email", claims.Email)
c.Set("user_name", claims.Name)
if claims.AvatarURL != nil {
c.Set("avatar_url", *claims.AvatarURL)
}
}
}
c.Next()
}
return func(c *gin.Context) {
user, claims, _ := m.getUserFromToken(c)
if user != nil {
c.Set(string(UserContextKey), user)
c.Set(ContextUserIDKey, user)
if claims != nil {
c.Set("user_email", claims.Email)
c.Set("user_name", claims.Name)
if claims.AvatarURL != nil {
c.Set("avatar_url", *claims.AvatarURL)
}
}
}
c.Next()
}
}
// getUserFromToken parses the JWT and returns the UserID and the full Claims (for name/email)
// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error)
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
authHeader := c.GetHeader("Authorization")
fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)
authHeader := c.GetHeader("Authorization")
fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)
if authHeader == "" {
fmt.Println("⚠️ getUserFromToken: No Authorization header")
return nil, nil, nil
}
if authHeader == "" {
fmt.Println("⚠️ getUserFromToken: No Authorization header")
return nil, nil, nil
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
return nil, nil, nil
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
return nil, nil, nil
}
tokenString := parts[1]
fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])
tokenString := parts[1]
fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
// 必须要验证签名算法是 HMAC (HS256)
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(m.jwtSecret), nil
})
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
// 必须要验证签名算法是 HMAC (HS256)
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(m.jwtSecret), nil
})
if err != nil {
fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
return nil, nil, err
}
if err != nil {
fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
return nil, nil, err
}
// 2. 验证 Token 有效性并提取 Claims
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
// 3. 把 String 类型的 Subject 转回 UUID
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
userID, err := uuid.Parse(claims.Subject)
if err != nil {
fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
return nil, nil, fmt.Errorf("invalid user ID in token")
}
// 2. 验证 Token 有效性并提取 Claims
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
// 3. 把 String 类型的 Subject 转回 UUID
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
userID, err := uuid.Parse(claims.Subject)
if err != nil {
fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
return nil, nil, fmt.Errorf("invalid user ID in token")
}
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
// 这一步完全没有查数据库,速度极快
fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
return &userID, claims, nil
}
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
// 这一步完全没有查数据库,速度极快
fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
return &userID, claims, nil
}
fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
return nil, nil, fmt.Errorf("invalid token claims")
fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
return nil, nil, fmt.Errorf("invalid token claims")
}
// GetUserFromContext extracts user ID from context
func GetUserFromContext(c *gin.Context) *uuid.UUID {
// 修正点:使用和存入时完全一样的 Key
val, exists := c.Get(ContextUserIDKey)
fmt.Println("within getFromContext the id is ... ")
fmt.Println(val);
if !exists {
return nil
}
// 修正点:使用和存入时完全一样的 Key
val, exists := c.Get(ContextUserIDKey)
fmt.Println("within getFromContext the id is ... ")
fmt.Println(val)
if !exists {
return nil
}
// 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型)
uid, ok := val.(*uuid.UUID)
if !ok {
return nil
}
// 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型)
uid, ok := val.(*uuid.UUID)
if !ok {
return nil
}
return uid
return uid
}
// ValidateToken validates a JWT token and returns user ID, name, and avatar URL from JWT claims

View File

@@ -8,25 +8,25 @@ import (
// GetGoogleOAuthConfig returns Google OAuth2 config
func GetGoogleOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
return &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
Scopes: []string{
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
}
return &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
Scopes: []string{
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
}
}
// GetGitHubOAuthConfig returns GitHub OAuth2 config
func GetGitHubOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
return &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
Scopes: []string{"user:email"},
Endpoint: github.Endpoint,
}
return &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
Scopes: []string{"user:email"},
Endpoint: github.Endpoint,
}
}

View File

@@ -64,10 +64,10 @@ func (h *AuthHandler) GoogleLogin(c *gin.Context) {
// GoogleCallback handles Google OAuth callback
func (h *AuthHandler) GoogleCallback(c *gin.Context) {
oauthState, err := c.Cookie("oauthstate")
if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return
}
if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return
}
log.Println("Google callback state:", c.Query("state"))
// Exchange code for token
token, err := h.googleConfig.Exchange(c.Request.Context(), c.Query("code"))
@@ -94,11 +94,11 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
Name string `json:"name"`
Picture string `json:"picture"`
}
if err := json.Unmarshal(data, &userInfo); err != nil {
log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"})
return
log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"})
return
}
log.Println("Google user info:", userInfo)
// Upsert user in database
@@ -118,10 +118,10 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
// Create session and JWT
jwt, err := h.createSessionAndJWT(c, user)
if err != nil {
fmt.Printf("❌ DATABASE ERROR: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("CreateSession Error: %v", err),
})
fmt.Printf("❌ DATABASE ERROR: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("CreateSession Error: %v", err),
})
return
}
@@ -140,10 +140,10 @@ func (h *AuthHandler) GithubLogin(c *gin.Context) {
// GithubCallback handles GitHub OAuth callback
func (h *AuthHandler) GithubCallback(c *gin.Context) {
oauthState, err := c.Cookie("oauthstate")
if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return
}
if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return
}
log.Println("Github callback state:", c.Query("state"))
code := c.Query("code")
if code == "" {
@@ -160,7 +160,7 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
// Get user info from GitHub
client := h.githubConfig.Client(c.Request.Context(), token)
// Get user profile
resp, err := client.Get("https://api.github.com/user")
if err != nil {
@@ -178,10 +178,10 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
AvatarURL string `json:"avatar_url"`
}
if err := json.Unmarshal(data, &userInfo); err != nil {
log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"})
return
}
log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"})
return
}
// If email is not public, fetch it separately
if userInfo.Email == "" {
@@ -315,10 +315,10 @@ func (h *AuthHandler) generateStateOauthCookie(w http.ResponseWriter) string {
Name: "oauthstate",
Value: state,
Expires: time.Now().Add(10 * time.Minute),
HttpOnly: true, // Prevents JavaScript access (XSS protection)
Secure: h.cfg.SecureCookie, // true in production, false for localhost
SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects
Path: "/", // Ensures cookie is sent to all backend paths
HttpOnly: true, // Prevents JavaScript access (XSS protection)
Secure: h.cfg.SecureCookie, // true in production, false for localhost
SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects
Path: "/", // Ensures cookie is sent to all backend paths
}
http.SetCookie(w, &cookie)

View File

@@ -11,16 +11,15 @@ import (
"github.com/google/uuid"
)
type DocumentHandler struct {
store *store.PostgresStore
}
type DocumentHandler struct {
store *store.PostgresStore
}
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
return &DocumentHandler{store: s}
}
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
return &DocumentHandler{store: s}
}
// CreateDocument creates a new document (requires auth)
// CreateDocument creates a new document (requires auth)
func (h *DocumentHandler) CreateDocument(c *gin.Context) {
userID := auth.GetUserFromContext(c)
if userID == nil {
@@ -44,7 +43,7 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
c.JSON(http.StatusCreated, doc)
}
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
userID := auth.GetUserFromContext(c)
fmt.Println("Getting userId, which is : ")
fmt.Println(userID)
@@ -66,8 +65,7 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
})
}
func (h *DocumentHandler) GetDocument(c *gin.Context) {
func (h *DocumentHandler) GetDocument(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
respondBadRequest(c, "Invalid document ID format")
@@ -104,8 +102,9 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
c.JSON(http.StatusOK, doc)
}
// GetDocumentState returns the Yjs state for a document
// GetDocumentState retrieves document state (requires view permission)
// GetDocumentState returns the Yjs state for a document
// GetDocumentState retrieves document state (requires view permission)
func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
@@ -143,7 +142,7 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
c.Data(http.StatusOK, "application/octet-stream", state)
}
// UpdateDocumentState updates document state (requires edit permission)
// UpdateDocumentState updates document state (requires edit permission)
func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
@@ -195,7 +194,7 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
}
// DeleteDocument deletes a document (owner only)
// DeleteDocument deletes a document (owner only)
func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
@@ -237,107 +236,109 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
// GetDocumentPermission returns the user's permission level for a document
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id"))
if err != nil {
respondBadRequest(c, "Invalid document ID format")
return
}
documentID, err := uuid.Parse(c.Param("id"))
if err != nil {
respondBadRequest(c, "Invalid document ID format")
return
}
// 1. 先检查文档是否存在 (Good practice)
_, err = h.store.GetDocument(documentID)
if err != nil {
respondNotFound(c, "document")
return
}
// 1. 先检查文档是否存在 (Good practice)
_, err = h.store.GetDocument(documentID)
if err != nil {
respondNotFound(c, "document")
return
}
userID := auth.GetUserFromContext(c)
shareToken := c.Query("share")
userID := auth.GetUserFromContext(c)
shareToken := c.Query("share")
// 定义两个临时变量,用来存两边的结果
var userPerm string // 存 document_shares 的结果
var tokenPerm string // 存 share_token 的结果
// 定义两个临时变量,用来存两边的结果
var userPerm string // 存 document_shares 的结果
var tokenPerm string // 存 share_token 的结果
// ====================================================
// 步骤 A: 检查个人权限 (Base Permission)
// ====================================================
if userID != nil {
perm, err := h.store.GetUserPermission(c.Request.Context(), documentID, *userID)
if err != nil {
respondInternalError(c, "Failed to get user permission", err)
return
}
userPerm = perm
// ⚠️ 注意:如果 perm 是空,这里不报错!继续往下走!
}
// ====================================================
// 步骤 A: 检查个人权限 (Base Permission)
// ====================================================
if userID != nil {
perm, err := h.store.GetUserPermission(c.Request.Context(), documentID, *userID)
if err != nil {
respondInternalError(c, "Failed to get user permission", err)
return
}
userPerm = perm
// ⚠️ 注意:如果 perm 是空,这里不报错!继续往下走!
}
// ====================================================
// 步骤 B: 检查 Token 权限 (Upgrade Permission)
// ====================================================
if shareToken != "" {
// 先验证 Token 是否有效
valid, err := h.store.ValidateShareToken(c.Request.Context(), documentID, shareToken)
if err != nil {
respondInternalError(c, "Failed to validate token", err)
return
}
// 只有 Token 有效才去取权限
if valid {
p, err := h.store.GetShareLinkPermission(c.Request.Context(), documentID)
if err != nil {
respondInternalError(c, "Failed to get token permission", err)
return
}
tokenPerm = p
// 处理数据库老数据的 fallback
if tokenPerm == "" { tokenPerm = "view" }
}
}
// ====================================================
// 步骤 B: 检查 Token 权限 (Upgrade Permission)
// ====================================================
if shareToken != "" {
// 先验证 Token 是否有效
valid, err := h.store.ValidateShareToken(c.Request.Context(), documentID, shareToken)
if err != nil {
respondInternalError(c, "Failed to validate token", err)
return
}
// ====================================================
// 步骤 C: ⚡️ 权限合并与计算 (The Brain)
// ====================================================
finalPermission := ""
role := "viewer" // 默认角色
// 只有 Token 有效才去取权限
if valid {
p, err := h.store.GetShareLinkPermission(c.Request.Context(), documentID)
if err != nil {
respondInternalError(c, "Failed to get token permission", err)
return
}
tokenPerm = p
// 处理数据库老数据的 fallback
if tokenPerm == "" {
tokenPerm = "view"
}
}
}
// 1. 如果是 Owner无敌直接返回
if userPerm == "owner" {
finalPermission = "edit"
role = "owner"
// 直接返回,不用看 Token 了
c.JSON(http.StatusOK, models.PermissionResponse{
Permission: finalPermission,
Role: role,
})
return
}
// ====================================================
// 步骤 C: ⚡️ 权限合并与计算 (The Brain)
// ====================================================
// 2. 比较 User 和 Token取最大值
// 逻辑:只要任意一边给了 "edit",那就是 "edit"
if userPerm == "edit" || tokenPerm == "edit" {
finalPermission = "edit"
role = "editor"
} else if userPerm == "view" || tokenPerm == "view" {
finalPermission = "view"
role = "viewer"
}
finalPermission := ""
role := "viewer" // 默认角色
// ====================================================
// 步骤 D: 最终判决
// ====================================================
if finalPermission == "" {
// 既没个人权限Token 也不对(或者没 Token
if userID == nil {
respondUnauthorized(c, "Authentication required") // 没登录且没Token
} else {
respondForbidden(c, "You don't have permission") // 登录了但没权限
}
return
}
// 1. 如果是 Owner无敌直接返回
if userPerm == "owner" {
finalPermission = "edit"
role = "owner"
// 直接返回,不用看 Token
c.JSON(http.StatusOK, models.PermissionResponse{
Permission: finalPermission,
Role: role,
})
return
}
c.JSON(http.StatusOK, models.PermissionResponse{
Permission: finalPermission,
Role: role,
})
}
// 2. 比较 User 和 Token取最大值
// 逻辑:只要任意一边给了 "edit",那就是 "edit"
if userPerm == "edit" || tokenPerm == "edit" {
finalPermission = "edit"
role = "editor"
} else if userPerm == "view" || tokenPerm == "view" {
finalPermission = "view"
role = "viewer"
}
// ====================================================
// 步骤 D: 最终判决
// ====================================================
if finalPermission == "" {
// 既没个人权限Token 也不对(或者没 Token
if userID == nil {
respondUnauthorized(c, "Authentication required") // 没登录且没Token
} else {
respondForbidden(c, "You don't have permission") // 登录了但没权限
}
return
}
c.JSON(http.StatusOK, models.PermissionResponse{
Permission: finalPermission,
Role: role,
})
}

View File

@@ -91,5 +91,5 @@ func respondInternalError(c *gin.Context, message string, err error) {
respondWithError(c, http.StatusInternalServerError, "internal_error", message)
}
func respondInvalidID(c *gin.Context, message string) {
respondWithError(c, http.StatusBadRequest, "invalid_id", message)
respondWithError(c, http.StatusBadRequest, "invalid_id", message)
}

View File

@@ -153,6 +153,7 @@ func (h *ShareHandler) DeleteShare(c *gin.Context) {
c.Status(204)
}
// CreateShareLink generates a public share link
func (h *ShareHandler) CreateShareLink(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id"))
@@ -290,4 +291,4 @@ func (h *ShareHandler) RevokeShareLink(c *gin.Context) {
// c.JSON(http.StatusOK, gin.H{"message": "Share link revoked successfully"})
c.Status(204)
}
}

View File

@@ -70,7 +70,7 @@ func TestShareHandlerSuite(t *testing.T) {
func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() {
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "view",
}
@@ -87,7 +87,7 @@ func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() {
func (s *ShareHandlerSuite) TestCreateShare_EditPermission() {
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "edit",
}
@@ -104,7 +104,7 @@ func (s *ShareHandlerSuite) TestCreateShare_EditPermission() {
func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() {
body := map[string]interface{}{
"user_email": "charlie@test.com",
"user_email": "charlie@test.com",
"permission": "view",
}
@@ -118,7 +118,7 @@ func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() {
func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() {
body := map[string]interface{}{
"user_email": "nonexistent@test.com",
"user_email": "nonexistent@test.com",
"permission": "view",
}
@@ -131,7 +131,7 @@ func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() {
func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() {
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "admin", // Invalid permission
}
@@ -145,7 +145,7 @@ func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() {
func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() {
// Create initial share with view permission
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "view",
}
@@ -169,7 +169,7 @@ func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() {
func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() {
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "view",
}
@@ -182,7 +182,7 @@ func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() {
func (s *ShareHandlerSuite) TestCreateShare_InvalidDocumentID() {
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "view",
}
@@ -206,8 +206,8 @@ func (s *ShareHandlerSuite) TestListShares_OwnerSeesAll() {
s.assertSuccessResponse(w, http.StatusOK)
var response models.ShareListResponse
s.parseJSONResponse(w, &response)
shares := response.Shares
s.parseJSONResponse(w, &response)
shares := response.Shares
s.GreaterOrEqual(len(shares), 1, "Should have at least one share")
}
@@ -229,8 +229,8 @@ func (s *ShareHandlerSuite) TestListShares_EmptyList() {
s.assertSuccessResponse(w, http.StatusOK)
var response models.ShareListResponse
s.parseJSONResponse(w, &response)
shares := response.Shares
s.parseJSONResponse(w, &response)
shares := response.Shares
s.Equal(0, len(shares), "Should have no shares")
}
@@ -243,8 +243,8 @@ func (s *ShareHandlerSuite) TestListShares_IncludesUserDetails() {
s.assertSuccessResponse(w, http.StatusOK)
var response models.ShareListResponse
s.parseJSONResponse(w, &response)
shares := response.Shares
s.parseJSONResponse(w, &response)
shares := response.Shares
if len(shares) > 0 {
share := shares[0]
@@ -266,7 +266,7 @@ func (s *ShareHandlerSuite) TestListShares_OrderedByCreatedAt() {
users := []string{"bob@test.com", "charlie@test.com"}
for _, email := range users {
body := map[string]interface{}{
"user_email": email,
"user_email": email,
"permission": "view",
}
w, httpReq, err := s.makeAuthRequest("POST", fmt.Sprintf("/api/documents/%s/shares", s.testData.AlicePrivateDoc), body, s.testData.AliceID)

View File

@@ -48,10 +48,10 @@ func (wsh *WebSocketHandler) HandleWebSocketLoadTest(c *gin.Context) {
clientID := uuid.New().String()
client := hub.NewClient(
clientID,
nil, // userID - nil for anonymous
userName, // userName
nil, // userAvatar
"edit", // permission - full access for load testing
nil, // userID - nil for anonymous
userName, // userName
nil, // userAvatar
"edit", // permission - full access for load testing
conn,
wsh.hub,
roomID,

View File

@@ -14,28 +14,26 @@ const (
)
type Document struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Type DocumentType `json:"type"`
YjsState []byte `json:"-"`
OwnerID *uuid.UUID `json:"owner_id"` // NEW
Is_Public bool `json:"is_public"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Type DocumentType `json:"type"`
YjsState []byte `json:"-"`
OwnerID *uuid.UUID `json:"owner_id"` // NEW
Is_Public bool `json:"is_public"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type CreateDocumentRequest struct {
Name string `json:"name" binding:"required"`
Type DocumentType `json:"type" binding:"required"`
}
Name string `json:"name" binding:"required"`
Type DocumentType `json:"type" binding:"required"`
}
type UpdateStateRequest struct {
State []byte `json:"state" binding:"required"`
}
type DocumentListResponse struct {
Documents []Document `json:"documents"`
Total int `json:"total"`
}
type UpdateStateRequest struct {
State []byte `json:"state" binding:"required"`
}
type DocumentListResponse struct {
Documents []Document `json:"documents"`
Total int `json:"total"`
}

View File

@@ -7,30 +7,30 @@ import (
)
type DocumentShare struct {
ID uuid.UUID `json:"id"`
DocumentID uuid.UUID `json:"document_id"`
UserID uuid.UUID `json:"user_id"`
Permission string `json:"permission"` // "view" or "edit"
CreatedAt time.Time `json:"created_at"`
CreatedBy *uuid.UUID `json:"created_by"`
ID uuid.UUID `json:"id"`
DocumentID uuid.UUID `json:"document_id"`
UserID uuid.UUID `json:"user_id"`
Permission string `json:"permission"` // "view" or "edit"
CreatedAt time.Time `json:"created_at"`
CreatedBy *uuid.UUID `json:"created_by"`
}
type CreateShareRequest struct {
UserEmail string `json:"user_email" binding:"required"`
Permission string `json:"permission" binding:"required,oneof=view edit"`
UserEmail string `json:"user_email" binding:"required"`
Permission string `json:"permission" binding:"required,oneof=view edit"`
}
type ShareListResponse struct {
Shares []DocumentShareWithUser `json:"shares"`
Shares []DocumentShareWithUser `json:"shares"`
}
type DocumentShareWithUser struct {
DocumentShare
User User `json:"user"`
DocumentShare
User User `json:"user"`
}
// PermissionResponse represents the user's permission level for a document
type PermissionResponse struct {
Permission string `json:"permission"` // "view" or "edit"
Role string `json:"role"` // "owner", "editor", or "viewer"
Permission string `json:"permission"` // "view" or "edit"
Role string `json:"role"` // "owner", "editor", or "viewer"
}

View File

@@ -7,42 +7,42 @@ import (
)
type User struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
AvatarURL *string `json:"avatar_url"`
Provider string `json:"provider"`
ProviderUserID string `json:"-"` // Don't expose
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
LastLoginAt *time.Time `json:"last_login_at"`
ID uuid.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
AvatarURL *string `json:"avatar_url"`
Provider string `json:"provider"`
ProviderUserID string `json:"-"` // Don't expose
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
LastLoginAt *time.Time `json:"last_login_at"`
}
type Session struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
TokenHash string `json:"-"` // SHA-256 hash of JWT
ExpiresAt time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
UserAgent *string `json:"user_agent"`
IPAddress *string `json:"ip_address"`
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
TokenHash string `json:"-"` // SHA-256 hash of JWT
ExpiresAt time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
UserAgent *string `json:"user_agent"`
IPAddress *string `json:"ip_address"`
}
type OAuthToken struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Provider string `json:"provider"`
AccessToken string `json:"-"` // Don't expose
RefreshToken *string `json:"-"`
TokenType string `json:"token_type"`
ExpiresAt time.Time `json:"expires_at"`
Scope *string `json:"scope"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Provider string `json:"provider"`
AccessToken string `json:"-"` // Don't expose
RefreshToken *string `json:"-"`
TokenType string `json:"token_type"`
ExpiresAt time.Time `json:"expires_at"`
Scope *string `json:"scope"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Response for /auth/me endpoint
type UserResponse struct {
User *User `json:"user"`
Token string `json:"token,omitempty"`
User *User `json:"user"`
Token string `json:"token,omitempty"`
}

View File

@@ -10,8 +10,8 @@ import (
type DocumentVersion struct {
ID uuid.UUID `json:"id"`
DocumentID uuid.UUID `json:"document_id"`
YjsSnapshot []byte `json:"-"` // Omit from JSON (binary)
TextPreview *string `json:"text_preview"` // Full plain text
YjsSnapshot []byte `json:"-"` // Omit from JSON (binary)
TextPreview *string `json:"text_preview"` // Full plain text
VersionNumber int `json:"version_number"`
CreatedBy *uuid.UUID `json:"created_by"`
VersionLabel *string `json:"version_label"`

View File

@@ -13,33 +13,33 @@ import (
// Store interface defines all database operations
type Store interface {
// Document operations
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
GetDocument(id uuid.UUID) (*models.Document, error)
ListDocuments() ([]models.Document, error)
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
UpdateDocumentState(id uuid.UUID, state []byte) error
DeleteDocument(id uuid.UUID) error
// User operations
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
// Session operations
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
DeleteSession(ctx context.Context, token string) error
CleanupExpiredSessions(ctx context.Context) error
// Share operations
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
// Document operations
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
GetDocument(id uuid.UUID) (*models.Document, error)
ListDocuments() ([]models.Document, error)
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
UpdateDocumentState(id uuid.UUID, state []byte) error
DeleteDocument(id uuid.UUID) error
// User operations
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
// Session operations
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
DeleteSession(ctx context.Context, token string) error
CleanupExpiredSessions(ctx context.Context) error
// Share operations
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error)
ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error)
RevokeShareToken(ctx context.Context, documentID uuid.UUID) error
@@ -53,13 +53,11 @@ type Store interface {
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
Close() error
Close() error
}
type PostgresStore struct {
db *sql.DB
db *sql.DB
}
func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
@@ -68,17 +66,17 @@ func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
return nil, error
}
if err := db.Ping(); err != nil {
return nil, fmt.Errorf("failed to ping database: %w", err)
}
return nil, fmt.Errorf("failed to ping database: %w", err)
}
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
return &PostgresStore{db: db}, nil
}
func (s *PostgresStore) Close() error {
return s.db.Close()
}
return s.db.Close()
}
func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
doc := &models.Document{
@@ -95,12 +93,11 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
RETURNING id, name, type, created_at, updated_at
`
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.CreatedAt,
doc.ID,
doc.Name,
doc.Type,
doc.CreatedAt,
doc.UpdatedAt,
).Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
@@ -109,164 +106,163 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
return doc, nil
}
// GetDocument retrieves a document by ID
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
doc := &models.Document{}
// GetDocument retrieves a document by ID
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
doc := &models.Document{}
query := `
query := `
SELECT id, name, type, yjs_state, owner_id, is_public, created_at, updated_at
FROM documents
WHERE id = $1
`
err := s.db.QueryRow(query, id).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.YjsState,
&doc.OwnerID,
&doc.Is_Public,
&doc.CreatedAt,
&doc.UpdatedAt,
)
err := s.db.QueryRow(query, id).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.YjsState,
&doc.OwnerID,
&doc.Is_Public,
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, fmt.Errorf("document not found")
}
if err != nil {
return nil, fmt.Errorf("failed to get document: %w", err)
}
if err == sql.ErrNoRows {
return nil, fmt.Errorf("document not found")
}
if err != nil {
return nil, fmt.Errorf("failed to get document: %w", err)
}
return doc, nil
}
return doc, nil
}
// ListDocuments retrieves all documents
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
query := `
// ListDocuments retrieves all documents
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
query := `
SELECT id, name, type, created_at, updated_at
FROM documents
ORDER BY created_at DESC
`
rows, err := s.db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to list documents: %w", err)
}
defer rows.Close()
rows, err := s.db.Query(query)
if err != nil {
return nil, fmt.Errorf("failed to list documents: %w", err)
}
defer rows.Close()
var documents []models.Document
for rows.Next() {
var doc models.Document
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan document: %w", err)
}
documents = append(documents, doc)
}
var documents []models.Document
for rows.Next() {
var doc models.Document
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan document: %w", err)
}
documents = append(documents, doc)
}
return documents, nil
}
return documents, nil
}
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
query := `
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
query := `
UPDATE documents
SET yjs_state = $1, updated_at = $2
WHERE id = $3
`
result, err := s.db.Exec(query, state, time.Now(), id)
if err != nil {
return fmt.Errorf("failed to update document state: %w", err)
}
result, err := s.db.Exec(query, state, time.Now(), id)
if err != nil {
return fmt.Errorf("failed to update document state: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
return nil
}
return nil
}
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
query := `DELETE FROM documents WHERE id = $1`
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
query := `DELETE FROM documents WHERE id = $1`
result, err := s.db.Exec(query, id)
if err != nil {
return fmt.Errorf("failed to delete document: %w", err)
}
result, err := s.db.Exec(query, id)
if err != nil {
return fmt.Errorf("failed to delete document: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return fmt.Errorf("failed to get rows affected: %w", err)
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
if rowsAffected == 0 {
return fmt.Errorf("document not found")
}
return nil
}
return nil
}
// CreateDocumentWithOwner creates a new document with owner
func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) {
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
if docType == "" {
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
}
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
if docType == "" {
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
}
// Validate that docType is one of the allowed values
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
}
// Validate that docType is one of the allowed values
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
}
doc := &models.Document{
ID: uuid.New(),
Name: name,
Type: docType,
YjsState: []byte{}, // 这里初始化了空字节
OwnerID: ownerID,
Is_Public: false, // 显式设置默认值
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
doc := &models.Document{
ID: uuid.New(),
Name: name,
Type: docType,
YjsState: []byte{}, // 这里初始化了空字节
OwnerID: ownerID,
Is_Public: false, // 显式设置默认值
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 2. 补全了 yjs_state 和 is_public
query := `
// 2. 补全了 yjs_state 和 is_public
query := `
INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at
`
// 3. Scan 的时候也要对应加上
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.OwnerID,
doc.YjsState, // $5
doc.Is_Public, // $6
doc.CreatedAt,
doc.UpdatedAt,
).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.OwnerID,
&doc.YjsState, // Scan 回来
&doc.Is_Public, // Scan 回来
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to create document: %w", err)
}
return doc, nil
// 3. Scan 的时候也要对应加上
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.OwnerID,
doc.YjsState, // $5
doc.Is_Public, // $6
doc.CreatedAt,
doc.UpdatedAt,
).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.OwnerID,
&doc.YjsState, // Scan 回来
&doc.Is_Public, // Scan 回来
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to create document: %w", err)
}
return doc, nil
}
// ListUserDocuments lists documents owned by or shared with a user

View File

@@ -12,77 +12,77 @@ import (
// CreateSession creates a new session
func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) {
// Hash the token before storing
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
// Hash the token before storing
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
// 【修改点 1】: 在 SQL 里显式加上 id 字段
// 注意:$1 变成了 id后面的参数序号全部要顺延 (+1)
query := `
// 【修改点 1】: 在 SQL 里显式加上 id 字段
// 注意:$1 变成了 id后面的参数序号全部要顺延 (+1)
query := `
INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
`
var session models.Session
// 【修改点 2】: 在参数列表的最前面加上 sessionID
// 现在的对应关系:
// $1 -> sessionID
// $2 -> userID
// $3 -> tokenHash
// ...
err := s.db.QueryRowContext(ctx, query,
sessionID, // <--- 这里!把它传进去!
userID,
tokenHash,
expiresAt,
userAgent,
ipAddress,
).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
var session models.Session
// 【修改点 2】: 在参数列表的最前面加上 sessionID
// 现在的对应关系:
// $1 -> sessionID
// $2 -> userID
// $3 -> tokenHash
// ...
err := s.db.QueryRowContext(ctx, query,
sessionID, // <--- 这里!把它传进去!
userID,
tokenHash,
expiresAt,
userAgent,
ipAddress,
).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
return &session, nil
return &session, nil
}
// GetSessionByToken retrieves session by JWT token
func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
query := `
query := `
SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
FROM sessions
WHERE token_hash = $1 AND expires_at > NOW()
`
var session models.Session
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
var session models.Session
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
return &session, nil
return &session, nil
}
// DeleteSession deletes a session (logout)
func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
return err
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
return err
}
// CleanupExpiredSessions removes expired sessions
func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
return err
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
return err
}

View File

@@ -14,34 +14,34 @@ import (
// CreateDocumentShare creates a new share or updates existing one
// Returns the share and a boolean indicating if it was newly created (true) or updated (false)
func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) {
// First check if share already exists
var existingID uuid.UUID
checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2`
err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID)
isNewShare := err != nil // If error (not found), it's a new share
// First check if share already exists
var existingID uuid.UUID
checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2`
err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID)
isNewShare := err != nil // If error (not found), it's a new share
query := `
query := `
INSERT INTO document_shares (document_id, user_id, permission, created_by)
VALUES ($1, $2, $3, $4)
ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission
RETURNING id, document_id, user_id, permission, created_at, created_by
`
var share models.DocumentShare
err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission,
&share.CreatedAt, &share.CreatedBy,
)
if err != nil {
return nil, false, err
}
var share models.DocumentShare
err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission,
&share.CreatedAt, &share.CreatedBy,
)
if err != nil {
return nil, false, err
}
return &share, isNewShare, nil
return &share, isNewShare, nil
}
// ListDocumentShares lists all shares for a document
func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) {
query := `
query := `
SELECT
ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by,
u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at
@@ -51,38 +51,38 @@ func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.
ORDER BY ds.created_at DESC
`
rows, err := s.db.QueryContext(ctx, query, documentID)
if err != nil {
return nil, err
}
defer rows.Close()
rows, err := s.db.QueryContext(ctx, query, documentID)
if err != nil {
return nil, err
}
defer rows.Close()
var shares []models.DocumentShareWithUser
for rows.Next() {
var share models.DocumentShareWithUser
err := rows.Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
)
if err != nil {
return nil, err
}
shares = append(shares, share)
}
var shares []models.DocumentShareWithUser
for rows.Next() {
var share models.DocumentShareWithUser
err := rows.Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
)
if err != nil {
return nil, err
}
shares = append(shares, share)
}
return shares, nil
return shares, nil
}
// DeleteDocumentShare deletes a share
func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
return err
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
return err
}
// CanViewDocument checks if user can view document (owner OR has any share)
func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := `
query := `
SELECT EXISTS(
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
UNION
@@ -90,14 +90,14 @@ func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID
)
`
var canView bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
return canView, err
var canView bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
return canView, err
}
// CanEditDocument checks if user can edit document (owner OR has edit share)
func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := `
query := `
SELECT EXISTS(
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
UNION
@@ -105,21 +105,21 @@ func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID
)
`
var canEdit bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
return canEdit, err
var canEdit bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
return canEdit, err
}
// IsDocumentOwner checks if user is the owner
func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
var isOwner bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
if err == sql.ErrNoRows {
return false, nil
}
return isOwner, err
var isOwner bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
if err == sql.ErrNoRows {
return false, nil
}
return isOwner, err
}
func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) {
// Generate random 32-byte token
@@ -239,4 +239,4 @@ func (s *PostgresStore) GetShareLinkPermission(ctx context.Context, documentID u
}
return permission, nil
}
}

View File

@@ -11,7 +11,7 @@ import (
// UpsertUser creates or updates user from OAuth profile
func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) {
query := `
query := `
INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at)
VALUES ($1, $2, $3, $4, $5, NOW())
ON CONFLICT (provider, provider_user_id)
@@ -24,59 +24,59 @@ func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID
RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
`
var user models.User
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err != nil {
return nil, err
}
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
var user models.User
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err != nil {
return nil, err
}
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
return &user, nil
return &user, nil
}
// GetUserByID retrieves user by ID
func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
query := `
query := `
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
FROM users WHERE id = $1
`
var user models.User
err := s.db.QueryRowContext(ctx, query, userID).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
var user models.User
err := s.db.QueryRowContext(ctx, query, userID).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &user, nil
return &user, nil
}
// GetUserByEmail retrieves user by email
func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
query := `
query := `
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
FROM users WHERE email = $1
`
var user models.User
err := s.db.QueryRowContext(ctx, query, email).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
var user models.User
err := s.db.QueryRowContext(ctx, query, email).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &user, nil
return &user, nil
}