Refactor and improve code consistency across multiple files
- Enhanced SQL queries in `session.go` and `share.go` for clarity and consistency. - Updated comments for better understanding and maintenance. - Ensured consistent error handling and return statements across various methods.
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -33,3 +33,5 @@ build/
|
|||||||
|
|
||||||
# Docker volumes and data
|
# Docker volumes and data
|
||||||
postgres_data/
|
postgres_data/
|
||||||
|
|
||||||
|
.claude/
|
||||||
@@ -18,142 +18,142 @@ const ContextUserIDKey = "user_id"
|
|||||||
|
|
||||||
// AuthMiddleware provides auth middleware
|
// AuthMiddleware provides auth middleware
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
store store.Store
|
store store.Store
|
||||||
jwtSecret string
|
jwtSecret string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware creates a new auth middleware
|
// NewAuthMiddleware creates a new auth middleware
|
||||||
func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware {
|
func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware {
|
||||||
return &AuthMiddleware{
|
return &AuthMiddleware{
|
||||||
store: store,
|
store: store,
|
||||||
jwtSecret: jwtSecret,
|
jwtSecret: jwtSecret,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequireAuth middleware requires valid authentication
|
// RequireAuth middleware requires valid authentication
|
||||||
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
|
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
fmt.Println("🔒 RequireAuth: Starting authentication check")
|
fmt.Println("🔒 RequireAuth: Starting authentication check")
|
||||||
|
|
||||||
user, claims, err := m.getUserFromToken(c)
|
user, claims, err := m.getUserFromToken(c)
|
||||||
|
|
||||||
fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
|
fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
|
||||||
if claims != nil {
|
if claims != nil {
|
||||||
fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
|
fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil || user == nil {
|
if err != nil || user == nil {
|
||||||
fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user)
|
fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user)
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: Name and Email might be empty for old JWT tokens
|
// Note: Name and Email might be empty for old JWT tokens
|
||||||
if claims.Name == "" || claims.Email == "" {
|
if claims.Name == "" || claims.Email == "" {
|
||||||
fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
|
fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
|
fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
|
||||||
c.Set(ContextUserIDKey, user)
|
c.Set(ContextUserIDKey, user)
|
||||||
c.Set("user_email", claims.Email)
|
c.Set("user_email", claims.Email)
|
||||||
c.Set("user_name", claims.Name)
|
c.Set("user_name", claims.Name)
|
||||||
if claims.AvatarURL != nil {
|
if claims.AvatarURL != nil {
|
||||||
c.Set("avatar_url", *claims.AvatarURL)
|
c.Set("avatar_url", *claims.AvatarURL)
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OptionalAuth middleware sets user if authenticated, but doesn't require it
|
// OptionalAuth middleware sets user if authenticated, but doesn't require it
|
||||||
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
|
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
user, claims, _ := m.getUserFromToken(c)
|
user, claims, _ := m.getUserFromToken(c)
|
||||||
if user != nil {
|
if user != nil {
|
||||||
c.Set(string(UserContextKey), user)
|
c.Set(string(UserContextKey), user)
|
||||||
c.Set(ContextUserIDKey, user)
|
c.Set(ContextUserIDKey, user)
|
||||||
if claims != nil {
|
if claims != nil {
|
||||||
c.Set("user_email", claims.Email)
|
c.Set("user_email", claims.Email)
|
||||||
c.Set("user_name", claims.Name)
|
c.Set("user_name", claims.Name)
|
||||||
if claims.AvatarURL != nil {
|
if claims.AvatarURL != nil {
|
||||||
c.Set("avatar_url", *claims.AvatarURL)
|
c.Set("avatar_url", *claims.AvatarURL)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.Next()
|
c.Next()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getUserFromToken parses the JWT and returns the UserID and the full Claims (for name/email)
|
// getUserFromToken parses the JWT and returns the UserID and the full Claims (for name/email)
|
||||||
// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error)
|
// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error)
|
||||||
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
|
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
|
||||||
authHeader := c.GetHeader("Authorization")
|
authHeader := c.GetHeader("Authorization")
|
||||||
fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)
|
fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)
|
||||||
|
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
fmt.Println("⚠️ getUserFromToken: No Authorization header")
|
fmt.Println("⚠️ getUserFromToken: No Authorization header")
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Split(authHeader, " ")
|
parts := strings.Split(authHeader, " ")
|
||||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||||
fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
|
fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenString := parts[1]
|
tokenString := parts[1]
|
||||||
fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])
|
fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])
|
||||||
|
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
// 必须要验证签名算法是 HMAC (HS256)
|
// 必须要验证签名算法是 HMAC (HS256)
|
||||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||||
}
|
}
|
||||||
return []byte(m.jwtSecret), nil
|
return []byte(m.jwtSecret), nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
|
fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 验证 Token 有效性并提取 Claims
|
// 2. 验证 Token 有效性并提取 Claims
|
||||||
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
|
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
|
||||||
// 3. 把 String 类型的 Subject 转回 UUID
|
// 3. 把 String 类型的 Subject 转回 UUID
|
||||||
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
|
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
|
||||||
userID, err := uuid.Parse(claims.Subject)
|
userID, err := uuid.Parse(claims.Subject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
|
fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
|
||||||
return nil, nil, fmt.Errorf("invalid user ID in token")
|
return nil, nil, fmt.Errorf("invalid user ID in token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
|
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
|
||||||
// 这一步完全没有查数据库,速度极快
|
// 这一步完全没有查数据库,速度极快
|
||||||
fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
|
fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
|
||||||
return &userID, claims, nil
|
return &userID, claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
|
fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
|
||||||
return nil, nil, fmt.Errorf("invalid token claims")
|
return nil, nil, fmt.Errorf("invalid token claims")
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserFromContext extracts user ID from context
|
// GetUserFromContext extracts user ID from context
|
||||||
func GetUserFromContext(c *gin.Context) *uuid.UUID {
|
func GetUserFromContext(c *gin.Context) *uuid.UUID {
|
||||||
// 修正点:使用和存入时完全一样的 Key
|
// 修正点:使用和存入时完全一样的 Key
|
||||||
val, exists := c.Get(ContextUserIDKey)
|
val, exists := c.Get(ContextUserIDKey)
|
||||||
fmt.Println("within getFromContext the id is ... ")
|
fmt.Println("within getFromContext the id is ... ")
|
||||||
fmt.Println(val);
|
fmt.Println(val)
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型)
|
// 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型)
|
||||||
uid, ok := val.(*uuid.UUID)
|
uid, ok := val.(*uuid.UUID)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return uid
|
return uid
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateToken validates a JWT token and returns user ID, name, and avatar URL from JWT claims
|
// ValidateToken validates a JWT token and returns user ID, name, and avatar URL from JWT claims
|
||||||
|
|||||||
@@ -8,25 +8,25 @@ import (
|
|||||||
|
|
||||||
// GetGoogleOAuthConfig returns Google OAuth2 config
|
// GetGoogleOAuthConfig returns Google OAuth2 config
|
||||||
func GetGoogleOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
|
func GetGoogleOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
|
||||||
return &oauth2.Config{
|
return &oauth2.Config{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
||||||
RedirectURL: redirectURL,
|
RedirectURL: redirectURL,
|
||||||
Scopes: []string{
|
Scopes: []string{
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
"https://www.googleapis.com/auth/userinfo.email",
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
"https://www.googleapis.com/auth/userinfo.profile",
|
||||||
},
|
},
|
||||||
Endpoint: google.Endpoint,
|
Endpoint: google.Endpoint,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetGitHubOAuthConfig returns GitHub OAuth2 config
|
// GetGitHubOAuthConfig returns GitHub OAuth2 config
|
||||||
func GetGitHubOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
|
func GetGitHubOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
|
||||||
return &oauth2.Config{
|
return &oauth2.Config{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
ClientSecret: clientSecret,
|
ClientSecret: clientSecret,
|
||||||
RedirectURL: redirectURL,
|
RedirectURL: redirectURL,
|
||||||
Scopes: []string{"user:email"},
|
Scopes: []string{"user:email"},
|
||||||
Endpoint: github.Endpoint,
|
Endpoint: github.Endpoint,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,10 +64,10 @@ func (h *AuthHandler) GoogleLogin(c *gin.Context) {
|
|||||||
// GoogleCallback handles Google OAuth callback
|
// GoogleCallback handles Google OAuth callback
|
||||||
func (h *AuthHandler) GoogleCallback(c *gin.Context) {
|
func (h *AuthHandler) GoogleCallback(c *gin.Context) {
|
||||||
oauthState, err := c.Cookie("oauthstate")
|
oauthState, err := c.Cookie("oauthstate")
|
||||||
if err != nil || c.Query("state") != oauthState {
|
if err != nil || c.Query("state") != oauthState {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Println("Google callback state:", c.Query("state"))
|
log.Println("Google callback state:", c.Query("state"))
|
||||||
// Exchange code for token
|
// Exchange code for token
|
||||||
token, err := h.googleConfig.Exchange(c.Request.Context(), c.Query("code"))
|
token, err := h.googleConfig.Exchange(c.Request.Context(), c.Query("code"))
|
||||||
@@ -96,9 +96,9 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := json.Unmarshal(data, &userInfo); err != nil {
|
if err := json.Unmarshal(data, &userInfo); err != nil {
|
||||||
log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data))
|
log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Println("Google user info:", userInfo)
|
log.Println("Google user info:", userInfo)
|
||||||
// Upsert user in database
|
// Upsert user in database
|
||||||
@@ -119,9 +119,9 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
|
|||||||
jwt, err := h.createSessionAndJWT(c, user)
|
jwt, err := h.createSessionAndJWT(c, user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("❌ DATABASE ERROR: %v\n", err)
|
fmt.Printf("❌ DATABASE ERROR: %v\n", err)
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{
|
c.JSON(http.StatusInternalServerError, gin.H{
|
||||||
"error": fmt.Sprintf("CreateSession Error: %v", err),
|
"error": fmt.Sprintf("CreateSession Error: %v", err),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -140,10 +140,10 @@ func (h *AuthHandler) GithubLogin(c *gin.Context) {
|
|||||||
// GithubCallback handles GitHub OAuth callback
|
// GithubCallback handles GitHub OAuth callback
|
||||||
func (h *AuthHandler) GithubCallback(c *gin.Context) {
|
func (h *AuthHandler) GithubCallback(c *gin.Context) {
|
||||||
oauthState, err := c.Cookie("oauthstate")
|
oauthState, err := c.Cookie("oauthstate")
|
||||||
if err != nil || c.Query("state") != oauthState {
|
if err != nil || c.Query("state") != oauthState {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
|
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Println("Github callback state:", c.Query("state"))
|
log.Println("Github callback state:", c.Query("state"))
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
if code == "" {
|
if code == "" {
|
||||||
@@ -178,10 +178,10 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
|
|||||||
AvatarURL string `json:"avatar_url"`
|
AvatarURL string `json:"avatar_url"`
|
||||||
}
|
}
|
||||||
if err := json.Unmarshal(data, &userInfo); err != nil {
|
if err := json.Unmarshal(data, &userInfo); err != nil {
|
||||||
log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data))
|
log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data))
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If email is not public, fetch it separately
|
// If email is not public, fetch it separately
|
||||||
if userInfo.Email == "" {
|
if userInfo.Email == "" {
|
||||||
@@ -315,10 +315,10 @@ func (h *AuthHandler) generateStateOauthCookie(w http.ResponseWriter) string {
|
|||||||
Name: "oauthstate",
|
Name: "oauthstate",
|
||||||
Value: state,
|
Value: state,
|
||||||
Expires: time.Now().Add(10 * time.Minute),
|
Expires: time.Now().Add(10 * time.Minute),
|
||||||
HttpOnly: true, // Prevents JavaScript access (XSS protection)
|
HttpOnly: true, // Prevents JavaScript access (XSS protection)
|
||||||
Secure: h.cfg.SecureCookie, // true in production, false for localhost
|
Secure: h.cfg.SecureCookie, // true in production, false for localhost
|
||||||
SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects
|
SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects
|
||||||
Path: "/", // Ensures cookie is sent to all backend paths
|
Path: "/", // Ensures cookie is sent to all backend paths
|
||||||
}
|
}
|
||||||
http.SetCookie(w, &cookie)
|
http.SetCookie(w, &cookie)
|
||||||
|
|
||||||
|
|||||||
@@ -11,16 +11,15 @@ import (
|
|||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DocumentHandler struct {
|
type DocumentHandler struct {
|
||||||
store *store.PostgresStore
|
store *store.PostgresStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
|
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
|
||||||
return &DocumentHandler{store: s}
|
return &DocumentHandler{store: s}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CreateDocument creates a new document (requires auth)
|
||||||
// CreateDocument creates a new document (requires auth)
|
|
||||||
func (h *DocumentHandler) CreateDocument(c *gin.Context) {
|
func (h *DocumentHandler) CreateDocument(c *gin.Context) {
|
||||||
userID := auth.GetUserFromContext(c)
|
userID := auth.GetUserFromContext(c)
|
||||||
if userID == nil {
|
if userID == nil {
|
||||||
@@ -44,7 +43,7 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
|
|||||||
c.JSON(http.StatusCreated, doc)
|
c.JSON(http.StatusCreated, doc)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
|
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
|
||||||
userID := auth.GetUserFromContext(c)
|
userID := auth.GetUserFromContext(c)
|
||||||
fmt.Println("Getting userId, which is : ")
|
fmt.Println("Getting userId, which is : ")
|
||||||
fmt.Println(userID)
|
fmt.Println(userID)
|
||||||
@@ -66,8 +65,7 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *DocumentHandler) GetDocument(c *gin.Context) {
|
||||||
func (h *DocumentHandler) GetDocument(c *gin.Context) {
|
|
||||||
id, err := uuid.Parse(c.Param("id"))
|
id, err := uuid.Parse(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondBadRequest(c, "Invalid document ID format")
|
respondBadRequest(c, "Invalid document ID format")
|
||||||
@@ -104,8 +102,9 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
|
|||||||
|
|
||||||
c.JSON(http.StatusOK, doc)
|
c.JSON(http.StatusOK, doc)
|
||||||
}
|
}
|
||||||
// GetDocumentState returns the Yjs state for a document
|
|
||||||
// GetDocumentState retrieves document state (requires view permission)
|
// GetDocumentState returns the Yjs state for a document
|
||||||
|
// GetDocumentState retrieves document state (requires view permission)
|
||||||
func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
|
func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
|
||||||
id, err := uuid.Parse(c.Param("id"))
|
id, err := uuid.Parse(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -143,7 +142,7 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
|
|||||||
c.Data(http.StatusOK, "application/octet-stream", state)
|
c.Data(http.StatusOK, "application/octet-stream", state)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateDocumentState updates document state (requires edit permission)
|
// UpdateDocumentState updates document state (requires edit permission)
|
||||||
func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
|
func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
|
||||||
id, err := uuid.Parse(c.Param("id"))
|
id, err := uuid.Parse(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -195,7 +194,7 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
|
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDocument deletes a document (owner only)
|
// DeleteDocument deletes a document (owner only)
|
||||||
func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
|
func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
|
||||||
id, err := uuid.Parse(c.Param("id"))
|
id, err := uuid.Parse(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -237,107 +236,109 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
|
|||||||
|
|
||||||
// GetDocumentPermission returns the user's permission level for a document
|
// GetDocumentPermission returns the user's permission level for a document
|
||||||
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
|
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
|
||||||
documentID, err := uuid.Parse(c.Param("id"))
|
documentID, err := uuid.Parse(c.Param("id"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondBadRequest(c, "Invalid document ID format")
|
respondBadRequest(c, "Invalid document ID format")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. 先检查文档是否存在 (Good practice)
|
// 1. 先检查文档是否存在 (Good practice)
|
||||||
_, err = h.store.GetDocument(documentID)
|
_, err = h.store.GetDocument(documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondNotFound(c, "document")
|
respondNotFound(c, "document")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
userID := auth.GetUserFromContext(c)
|
userID := auth.GetUserFromContext(c)
|
||||||
shareToken := c.Query("share")
|
shareToken := c.Query("share")
|
||||||
|
|
||||||
// 定义两个临时变量,用来存两边的结果
|
// 定义两个临时变量,用来存两边的结果
|
||||||
var userPerm string // 存 document_shares 的结果
|
var userPerm string // 存 document_shares 的结果
|
||||||
var tokenPerm string // 存 share_token 的结果
|
var tokenPerm string // 存 share_token 的结果
|
||||||
|
|
||||||
// ====================================================
|
// ====================================================
|
||||||
// 步骤 A: 检查个人权限 (Base Permission)
|
// 步骤 A: 检查个人权限 (Base Permission)
|
||||||
// ====================================================
|
// ====================================================
|
||||||
if userID != nil {
|
if userID != nil {
|
||||||
perm, err := h.store.GetUserPermission(c.Request.Context(), documentID, *userID)
|
perm, err := h.store.GetUserPermission(c.Request.Context(), documentID, *userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondInternalError(c, "Failed to get user permission", err)
|
respondInternalError(c, "Failed to get user permission", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
userPerm = perm
|
userPerm = perm
|
||||||
// ⚠️ 注意:如果 perm 是空,这里不报错!继续往下走!
|
// ⚠️ 注意:如果 perm 是空,这里不报错!继续往下走!
|
||||||
}
|
}
|
||||||
|
|
||||||
// ====================================================
|
// ====================================================
|
||||||
// 步骤 B: 检查 Token 权限 (Upgrade Permission)
|
// 步骤 B: 检查 Token 权限 (Upgrade Permission)
|
||||||
// ====================================================
|
// ====================================================
|
||||||
if shareToken != "" {
|
if shareToken != "" {
|
||||||
// 先验证 Token 是否有效
|
// 先验证 Token 是否有效
|
||||||
valid, err := h.store.ValidateShareToken(c.Request.Context(), documentID, shareToken)
|
valid, err := h.store.ValidateShareToken(c.Request.Context(), documentID, shareToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondInternalError(c, "Failed to validate token", err)
|
respondInternalError(c, "Failed to validate token", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 只有 Token 有效才去取权限
|
// 只有 Token 有效才去取权限
|
||||||
if valid {
|
if valid {
|
||||||
p, err := h.store.GetShareLinkPermission(c.Request.Context(), documentID)
|
p, err := h.store.GetShareLinkPermission(c.Request.Context(), documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
respondInternalError(c, "Failed to get token permission", err)
|
respondInternalError(c, "Failed to get token permission", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tokenPerm = p
|
tokenPerm = p
|
||||||
// 处理数据库老数据的 fallback
|
// 处理数据库老数据的 fallback
|
||||||
if tokenPerm == "" { tokenPerm = "view" }
|
if tokenPerm == "" {
|
||||||
}
|
tokenPerm = "view"
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ====================================================
|
// ====================================================
|
||||||
// 步骤 C: ⚡️ 权限合并与计算 (The Brain)
|
// 步骤 C: ⚡️ 权限合并与计算 (The Brain)
|
||||||
// ====================================================
|
// ====================================================
|
||||||
|
|
||||||
finalPermission := ""
|
finalPermission := ""
|
||||||
role := "viewer" // 默认角色
|
role := "viewer" // 默认角色
|
||||||
|
|
||||||
// 1. 如果是 Owner,无敌,直接返回
|
// 1. 如果是 Owner,无敌,直接返回
|
||||||
if userPerm == "owner" {
|
if userPerm == "owner" {
|
||||||
finalPermission = "edit"
|
finalPermission = "edit"
|
||||||
role = "owner"
|
role = "owner"
|
||||||
// 直接返回,不用看 Token 了
|
// 直接返回,不用看 Token 了
|
||||||
c.JSON(http.StatusOK, models.PermissionResponse{
|
c.JSON(http.StatusOK, models.PermissionResponse{
|
||||||
Permission: finalPermission,
|
Permission: finalPermission,
|
||||||
Role: role,
|
Role: role,
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 比较 User 和 Token,取最大值
|
// 2. 比较 User 和 Token,取最大值
|
||||||
// 逻辑:只要任意一边给了 "edit",那就是 "edit"
|
// 逻辑:只要任意一边给了 "edit",那就是 "edit"
|
||||||
if userPerm == "edit" || tokenPerm == "edit" {
|
if userPerm == "edit" || tokenPerm == "edit" {
|
||||||
finalPermission = "edit"
|
finalPermission = "edit"
|
||||||
role = "editor"
|
role = "editor"
|
||||||
} else if userPerm == "view" || tokenPerm == "view" {
|
} else if userPerm == "view" || tokenPerm == "view" {
|
||||||
finalPermission = "view"
|
finalPermission = "view"
|
||||||
role = "viewer"
|
role = "viewer"
|
||||||
}
|
}
|
||||||
|
|
||||||
// ====================================================
|
// ====================================================
|
||||||
// 步骤 D: 最终判决
|
// 步骤 D: 最终判决
|
||||||
// ====================================================
|
// ====================================================
|
||||||
if finalPermission == "" {
|
if finalPermission == "" {
|
||||||
// 既没个人权限,Token 也不对(或者没 Token)
|
// 既没个人权限,Token 也不对(或者没 Token)
|
||||||
if userID == nil {
|
if userID == nil {
|
||||||
respondUnauthorized(c, "Authentication required") // 没登录且没Token
|
respondUnauthorized(c, "Authentication required") // 没登录且没Token
|
||||||
} else {
|
} else {
|
||||||
respondForbidden(c, "You don't have permission") // 登录了但没权限
|
respondForbidden(c, "You don't have permission") // 登录了但没权限
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, models.PermissionResponse{
|
c.JSON(http.StatusOK, models.PermissionResponse{
|
||||||
Permission: finalPermission,
|
Permission: finalPermission,
|
||||||
Role: role,
|
Role: role,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -91,5 +91,5 @@ func respondInternalError(c *gin.Context, message string, err error) {
|
|||||||
respondWithError(c, http.StatusInternalServerError, "internal_error", message)
|
respondWithError(c, http.StatusInternalServerError, "internal_error", message)
|
||||||
}
|
}
|
||||||
func respondInvalidID(c *gin.Context, message string) {
|
func respondInvalidID(c *gin.Context, message string) {
|
||||||
respondWithError(c, http.StatusBadRequest, "invalid_id", message)
|
respondWithError(c, http.StatusBadRequest, "invalid_id", message)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -153,6 +153,7 @@ func (h *ShareHandler) DeleteShare(c *gin.Context) {
|
|||||||
|
|
||||||
c.Status(204)
|
c.Status(204)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateShareLink generates a public share link
|
// CreateShareLink generates a public share link
|
||||||
func (h *ShareHandler) CreateShareLink(c *gin.Context) {
|
func (h *ShareHandler) CreateShareLink(c *gin.Context) {
|
||||||
documentID, err := uuid.Parse(c.Param("id"))
|
documentID, err := uuid.Parse(c.Param("id"))
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ func TestShareHandlerSuite(t *testing.T) {
|
|||||||
|
|
||||||
func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() {
|
func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() {
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"user_email": "bob@test.com",
|
"user_email": "bob@test.com",
|
||||||
"permission": "view",
|
"permission": "view",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() {
|
|||||||
|
|
||||||
func (s *ShareHandlerSuite) TestCreateShare_EditPermission() {
|
func (s *ShareHandlerSuite) TestCreateShare_EditPermission() {
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"user_email": "bob@test.com",
|
"user_email": "bob@test.com",
|
||||||
"permission": "edit",
|
"permission": "edit",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -104,7 +104,7 @@ func (s *ShareHandlerSuite) TestCreateShare_EditPermission() {
|
|||||||
|
|
||||||
func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() {
|
func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() {
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"user_email": "charlie@test.com",
|
"user_email": "charlie@test.com",
|
||||||
"permission": "view",
|
"permission": "view",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,7 +118,7 @@ func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() {
|
|||||||
|
|
||||||
func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() {
|
func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() {
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"user_email": "nonexistent@test.com",
|
"user_email": "nonexistent@test.com",
|
||||||
"permission": "view",
|
"permission": "view",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,7 +131,7 @@ func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() {
|
|||||||
|
|
||||||
func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() {
|
func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() {
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"user_email": "bob@test.com",
|
"user_email": "bob@test.com",
|
||||||
"permission": "admin", // Invalid permission
|
"permission": "admin", // Invalid permission
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -145,7 +145,7 @@ func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() {
|
|||||||
func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() {
|
func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() {
|
||||||
// Create initial share with view permission
|
// Create initial share with view permission
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"user_email": "bob@test.com",
|
"user_email": "bob@test.com",
|
||||||
"permission": "view",
|
"permission": "view",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -169,7 +169,7 @@ func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() {
|
|||||||
|
|
||||||
func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() {
|
func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() {
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"user_email": "bob@test.com",
|
"user_email": "bob@test.com",
|
||||||
"permission": "view",
|
"permission": "view",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -182,7 +182,7 @@ func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() {
|
|||||||
|
|
||||||
func (s *ShareHandlerSuite) TestCreateShare_InvalidDocumentID() {
|
func (s *ShareHandlerSuite) TestCreateShare_InvalidDocumentID() {
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"user_email": "bob@test.com",
|
"user_email": "bob@test.com",
|
||||||
"permission": "view",
|
"permission": "view",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -206,8 +206,8 @@ func (s *ShareHandlerSuite) TestListShares_OwnerSeesAll() {
|
|||||||
s.assertSuccessResponse(w, http.StatusOK)
|
s.assertSuccessResponse(w, http.StatusOK)
|
||||||
|
|
||||||
var response models.ShareListResponse
|
var response models.ShareListResponse
|
||||||
s.parseJSONResponse(w, &response)
|
s.parseJSONResponse(w, &response)
|
||||||
shares := response.Shares
|
shares := response.Shares
|
||||||
s.GreaterOrEqual(len(shares), 1, "Should have at least one share")
|
s.GreaterOrEqual(len(shares), 1, "Should have at least one share")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -229,8 +229,8 @@ func (s *ShareHandlerSuite) TestListShares_EmptyList() {
|
|||||||
s.assertSuccessResponse(w, http.StatusOK)
|
s.assertSuccessResponse(w, http.StatusOK)
|
||||||
|
|
||||||
var response models.ShareListResponse
|
var response models.ShareListResponse
|
||||||
s.parseJSONResponse(w, &response)
|
s.parseJSONResponse(w, &response)
|
||||||
shares := response.Shares
|
shares := response.Shares
|
||||||
s.Equal(0, len(shares), "Should have no shares")
|
s.Equal(0, len(shares), "Should have no shares")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,8 +243,8 @@ func (s *ShareHandlerSuite) TestListShares_IncludesUserDetails() {
|
|||||||
s.assertSuccessResponse(w, http.StatusOK)
|
s.assertSuccessResponse(w, http.StatusOK)
|
||||||
|
|
||||||
var response models.ShareListResponse
|
var response models.ShareListResponse
|
||||||
s.parseJSONResponse(w, &response)
|
s.parseJSONResponse(w, &response)
|
||||||
shares := response.Shares
|
shares := response.Shares
|
||||||
|
|
||||||
if len(shares) > 0 {
|
if len(shares) > 0 {
|
||||||
share := shares[0]
|
share := shares[0]
|
||||||
@@ -266,7 +266,7 @@ func (s *ShareHandlerSuite) TestListShares_OrderedByCreatedAt() {
|
|||||||
users := []string{"bob@test.com", "charlie@test.com"}
|
users := []string{"bob@test.com", "charlie@test.com"}
|
||||||
for _, email := range users {
|
for _, email := range users {
|
||||||
body := map[string]interface{}{
|
body := map[string]interface{}{
|
||||||
"user_email": email,
|
"user_email": email,
|
||||||
"permission": "view",
|
"permission": "view",
|
||||||
}
|
}
|
||||||
w, httpReq, err := s.makeAuthRequest("POST", fmt.Sprintf("/api/documents/%s/shares", s.testData.AlicePrivateDoc), body, s.testData.AliceID)
|
w, httpReq, err := s.makeAuthRequest("POST", fmt.Sprintf("/api/documents/%s/shares", s.testData.AlicePrivateDoc), body, s.testData.AliceID)
|
||||||
|
|||||||
@@ -48,10 +48,10 @@ func (wsh *WebSocketHandler) HandleWebSocketLoadTest(c *gin.Context) {
|
|||||||
clientID := uuid.New().String()
|
clientID := uuid.New().String()
|
||||||
client := hub.NewClient(
|
client := hub.NewClient(
|
||||||
clientID,
|
clientID,
|
||||||
nil, // userID - nil for anonymous
|
nil, // userID - nil for anonymous
|
||||||
userName, // userName
|
userName, // userName
|
||||||
nil, // userAvatar
|
nil, // userAvatar
|
||||||
"edit", // permission - full access for load testing
|
"edit", // permission - full access for load testing
|
||||||
conn,
|
conn,
|
||||||
wsh.hub,
|
wsh.hub,
|
||||||
roomID,
|
roomID,
|
||||||
|
|||||||
@@ -14,28 +14,26 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Document struct {
|
type Document struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Type DocumentType `json:"type"`
|
Type DocumentType `json:"type"`
|
||||||
YjsState []byte `json:"-"`
|
YjsState []byte `json:"-"`
|
||||||
OwnerID *uuid.UUID `json:"owner_id"` // NEW
|
OwnerID *uuid.UUID `json:"owner_id"` // NEW
|
||||||
Is_Public bool `json:"is_public"`
|
Is_Public bool `json:"is_public"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
type CreateDocumentRequest struct {
|
type CreateDocumentRequest struct {
|
||||||
Name string `json:"name" binding:"required"`
|
Name string `json:"name" binding:"required"`
|
||||||
Type DocumentType `json:"type" binding:"required"`
|
Type DocumentType `json:"type" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateStateRequest struct {
|
type UpdateStateRequest struct {
|
||||||
State []byte `json:"state" binding:"required"`
|
State []byte `json:"state" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DocumentListResponse struct {
|
|
||||||
Documents []Document `json:"documents"`
|
|
||||||
Total int `json:"total"`
|
|
||||||
}
|
|
||||||
|
|
||||||
|
type DocumentListResponse struct {
|
||||||
|
Documents []Document `json:"documents"`
|
||||||
|
Total int `json:"total"`
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,30 +7,30 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type DocumentShare struct {
|
type DocumentShare struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
DocumentID uuid.UUID `json:"document_id"`
|
DocumentID uuid.UUID `json:"document_id"`
|
||||||
UserID uuid.UUID `json:"user_id"`
|
UserID uuid.UUID `json:"user_id"`
|
||||||
Permission string `json:"permission"` // "view" or "edit"
|
Permission string `json:"permission"` // "view" or "edit"
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
CreatedBy *uuid.UUID `json:"created_by"`
|
CreatedBy *uuid.UUID `json:"created_by"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateShareRequest struct {
|
type CreateShareRequest struct {
|
||||||
UserEmail string `json:"user_email" binding:"required"`
|
UserEmail string `json:"user_email" binding:"required"`
|
||||||
Permission string `json:"permission" binding:"required,oneof=view edit"`
|
Permission string `json:"permission" binding:"required,oneof=view edit"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type ShareListResponse struct {
|
type ShareListResponse struct {
|
||||||
Shares []DocumentShareWithUser `json:"shares"`
|
Shares []DocumentShareWithUser `json:"shares"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DocumentShareWithUser struct {
|
type DocumentShareWithUser struct {
|
||||||
DocumentShare
|
DocumentShare
|
||||||
User User `json:"user"`
|
User User `json:"user"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PermissionResponse represents the user's permission level for a document
|
// PermissionResponse represents the user's permission level for a document
|
||||||
type PermissionResponse struct {
|
type PermissionResponse struct {
|
||||||
Permission string `json:"permission"` // "view" or "edit"
|
Permission string `json:"permission"` // "view" or "edit"
|
||||||
Role string `json:"role"` // "owner", "editor", or "viewer"
|
Role string `json:"role"` // "owner", "editor", or "viewer"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,42 +7,42 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type User struct {
|
type User struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
AvatarURL *string `json:"avatar_url"`
|
AvatarURL *string `json:"avatar_url"`
|
||||||
Provider string `json:"provider"`
|
Provider string `json:"provider"`
|
||||||
ProviderUserID string `json:"-"` // Don't expose
|
ProviderUserID string `json:"-"` // Don't expose
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
LastLoginAt *time.Time `json:"last_login_at"`
|
LastLoginAt *time.Time `json:"last_login_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
UserID uuid.UUID `json:"user_id"`
|
UserID uuid.UUID `json:"user_id"`
|
||||||
TokenHash string `json:"-"` // SHA-256 hash of JWT
|
TokenHash string `json:"-"` // SHA-256 hash of JWT
|
||||||
ExpiresAt time.Time `json:"expires_at"`
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UserAgent *string `json:"user_agent"`
|
UserAgent *string `json:"user_agent"`
|
||||||
IPAddress *string `json:"ip_address"`
|
IPAddress *string `json:"ip_address"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthToken struct {
|
type OAuthToken struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
UserID uuid.UUID `json:"user_id"`
|
UserID uuid.UUID `json:"user_id"`
|
||||||
Provider string `json:"provider"`
|
Provider string `json:"provider"`
|
||||||
AccessToken string `json:"-"` // Don't expose
|
AccessToken string `json:"-"` // Don't expose
|
||||||
RefreshToken *string `json:"-"`
|
RefreshToken *string `json:"-"`
|
||||||
TokenType string `json:"token_type"`
|
TokenType string `json:"token_type"`
|
||||||
ExpiresAt time.Time `json:"expires_at"`
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
Scope *string `json:"scope"`
|
Scope *string `json:"scope"`
|
||||||
CreatedAt time.Time `json:"created_at"`
|
CreatedAt time.Time `json:"created_at"`
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Response for /auth/me endpoint
|
// Response for /auth/me endpoint
|
||||||
type UserResponse struct {
|
type UserResponse struct {
|
||||||
User *User `json:"user"`
|
User *User `json:"user"`
|
||||||
Token string `json:"token,omitempty"`
|
Token string `json:"token,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,8 +10,8 @@ import (
|
|||||||
type DocumentVersion struct {
|
type DocumentVersion struct {
|
||||||
ID uuid.UUID `json:"id"`
|
ID uuid.UUID `json:"id"`
|
||||||
DocumentID uuid.UUID `json:"document_id"`
|
DocumentID uuid.UUID `json:"document_id"`
|
||||||
YjsSnapshot []byte `json:"-"` // Omit from JSON (binary)
|
YjsSnapshot []byte `json:"-"` // Omit from JSON (binary)
|
||||||
TextPreview *string `json:"text_preview"` // Full plain text
|
TextPreview *string `json:"text_preview"` // Full plain text
|
||||||
VersionNumber int `json:"version_number"`
|
VersionNumber int `json:"version_number"`
|
||||||
CreatedBy *uuid.UUID `json:"created_by"`
|
CreatedBy *uuid.UUID `json:"created_by"`
|
||||||
VersionLabel *string `json:"version_label"`
|
VersionLabel *string `json:"version_label"`
|
||||||
|
|||||||
@@ -13,33 +13,33 @@ import (
|
|||||||
|
|
||||||
// Store interface defines all database operations
|
// Store interface defines all database operations
|
||||||
type Store interface {
|
type Store interface {
|
||||||
// Document operations
|
// Document operations
|
||||||
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
|
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
|
||||||
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
|
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
|
||||||
GetDocument(id uuid.UUID) (*models.Document, error)
|
GetDocument(id uuid.UUID) (*models.Document, error)
|
||||||
ListDocuments() ([]models.Document, error)
|
ListDocuments() ([]models.Document, error)
|
||||||
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
|
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
|
||||||
UpdateDocumentState(id uuid.UUID, state []byte) error
|
UpdateDocumentState(id uuid.UUID, state []byte) error
|
||||||
DeleteDocument(id uuid.UUID) error
|
DeleteDocument(id uuid.UUID) error
|
||||||
|
|
||||||
// User operations
|
// User operations
|
||||||
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
|
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
|
||||||
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
|
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
|
||||||
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
|
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
|
||||||
|
|
||||||
// Session operations
|
// Session operations
|
||||||
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
|
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
|
||||||
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
|
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
|
||||||
DeleteSession(ctx context.Context, token string) error
|
DeleteSession(ctx context.Context, token string) error
|
||||||
CleanupExpiredSessions(ctx context.Context) error
|
CleanupExpiredSessions(ctx context.Context) error
|
||||||
|
|
||||||
// Share operations
|
// Share operations
|
||||||
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
|
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error)
|
||||||
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
|
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
|
||||||
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
|
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
|
||||||
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||||
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||||
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||||
GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error)
|
GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error)
|
||||||
ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error)
|
ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error)
|
||||||
RevokeShareToken(ctx context.Context, documentID uuid.UUID) error
|
RevokeShareToken(ctx context.Context, documentID uuid.UUID) error
|
||||||
@@ -53,13 +53,11 @@ type Store interface {
|
|||||||
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
|
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
|
||||||
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
|
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
|
||||||
|
|
||||||
|
Close() error
|
||||||
Close() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
type PostgresStore struct {
|
type PostgresStore struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
|
func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
|
||||||
@@ -68,17 +66,17 @@ func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
|
|||||||
return nil, error
|
return nil, error
|
||||||
}
|
}
|
||||||
if err := db.Ping(); err != nil {
|
if err := db.Ping(); err != nil {
|
||||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||||
}
|
}
|
||||||
db.SetMaxOpenConns(25)
|
db.SetMaxOpenConns(25)
|
||||||
db.SetMaxIdleConns(5)
|
db.SetMaxIdleConns(5)
|
||||||
db.SetConnMaxLifetime(5 * time.Minute)
|
db.SetConnMaxLifetime(5 * time.Minute)
|
||||||
return &PostgresStore{db: db}, nil
|
return &PostgresStore{db: db}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PostgresStore) Close() error {
|
func (s *PostgresStore) Close() error {
|
||||||
return s.db.Close()
|
return s.db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
|
func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
|
||||||
doc := &models.Document{
|
doc := &models.Document{
|
||||||
@@ -100,7 +98,6 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
|
|||||||
doc.Type,
|
doc.Type,
|
||||||
doc.CreatedAt,
|
doc.CreatedAt,
|
||||||
doc.UpdatedAt,
|
doc.UpdatedAt,
|
||||||
|
|
||||||
).Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
|
).Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -109,164 +106,163 @@ func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType)
|
|||||||
return doc, nil
|
return doc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDocument retrieves a document by ID
|
// GetDocument retrieves a document by ID
|
||||||
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
|
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
|
||||||
doc := &models.Document{}
|
doc := &models.Document{}
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
SELECT id, name, type, yjs_state, owner_id, is_public, created_at, updated_at
|
SELECT id, name, type, yjs_state, owner_id, is_public, created_at, updated_at
|
||||||
FROM documents
|
FROM documents
|
||||||
WHERE id = $1
|
WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
err := s.db.QueryRow(query, id).Scan(
|
err := s.db.QueryRow(query, id).Scan(
|
||||||
&doc.ID,
|
&doc.ID,
|
||||||
&doc.Name,
|
&doc.Name,
|
||||||
&doc.Type,
|
&doc.Type,
|
||||||
&doc.YjsState,
|
&doc.YjsState,
|
||||||
&doc.OwnerID,
|
&doc.OwnerID,
|
||||||
&doc.Is_Public,
|
&doc.Is_Public,
|
||||||
&doc.CreatedAt,
|
&doc.CreatedAt,
|
||||||
&doc.UpdatedAt,
|
&doc.UpdatedAt,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, fmt.Errorf("document not found")
|
return nil, fmt.Errorf("document not found")
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get document: %w", err)
|
return nil, fmt.Errorf("failed to get document: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return doc, nil
|
return doc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListDocuments retrieves all documents
|
||||||
// ListDocuments retrieves all documents
|
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
|
||||||
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
|
query := `
|
||||||
query := `
|
|
||||||
SELECT id, name, type, created_at, updated_at
|
SELECT id, name, type, created_at, updated_at
|
||||||
FROM documents
|
FROM documents
|
||||||
ORDER BY created_at DESC
|
ORDER BY created_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := s.db.Query(query)
|
rows, err := s.db.Query(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to list documents: %w", err)
|
return nil, fmt.Errorf("failed to list documents: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
var documents []models.Document
|
var documents []models.Document
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var doc models.Document
|
var doc models.Document
|
||||||
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
|
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.CreatedAt, &doc.UpdatedAt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to scan document: %w", err)
|
return nil, fmt.Errorf("failed to scan document: %w", err)
|
||||||
}
|
}
|
||||||
documents = append(documents, doc)
|
documents = append(documents, doc)
|
||||||
}
|
}
|
||||||
|
|
||||||
return documents, nil
|
return documents, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
|
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
|
||||||
query := `
|
query := `
|
||||||
UPDATE documents
|
UPDATE documents
|
||||||
SET yjs_state = $1, updated_at = $2
|
SET yjs_state = $1, updated_at = $2
|
||||||
WHERE id = $3
|
WHERE id = $3
|
||||||
`
|
`
|
||||||
|
|
||||||
result, err := s.db.Exec(query, state, time.Now(), id)
|
result, err := s.db.Exec(query, state, time.Now(), id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to update document state: %w", err)
|
return fmt.Errorf("failed to update document state: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
rowsAffected, err := result.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return fmt.Errorf("document not found")
|
return fmt.Errorf("document not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
|
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
|
||||||
query := `DELETE FROM documents WHERE id = $1`
|
query := `DELETE FROM documents WHERE id = $1`
|
||||||
|
|
||||||
result, err := s.db.Exec(query, id)
|
result, err := s.db.Exec(query, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to delete document: %w", err)
|
return fmt.Errorf("failed to delete document: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rowsAffected, err := result.RowsAffected()
|
rowsAffected, err := result.RowsAffected()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get rows affected: %w", err)
|
return fmt.Errorf("failed to get rows affected: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rowsAffected == 0 {
|
if rowsAffected == 0 {
|
||||||
return fmt.Errorf("document not found")
|
return fmt.Errorf("document not found")
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDocumentWithOwner creates a new document with owner
|
// CreateDocumentWithOwner creates a new document with owner
|
||||||
func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) {
|
func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) {
|
||||||
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
|
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
|
||||||
if docType == "" {
|
if docType == "" {
|
||||||
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
|
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate that docType is one of the allowed values
|
// Validate that docType is one of the allowed values
|
||||||
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
|
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
|
||||||
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
|
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
|
||||||
}
|
}
|
||||||
|
|
||||||
doc := &models.Document{
|
doc := &models.Document{
|
||||||
ID: uuid.New(),
|
ID: uuid.New(),
|
||||||
Name: name,
|
Name: name,
|
||||||
Type: docType,
|
Type: docType,
|
||||||
YjsState: []byte{}, // 这里初始化了空字节
|
YjsState: []byte{}, // 这里初始化了空字节
|
||||||
OwnerID: ownerID,
|
OwnerID: ownerID,
|
||||||
Is_Public: false, // 显式设置默认值
|
Is_Public: false, // 显式设置默认值
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
UpdatedAt: time.Now(),
|
UpdatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 补全了 yjs_state 和 is_public
|
// 2. 补全了 yjs_state 和 is_public
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at)
|
INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||||
RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at
|
RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at
|
||||||
`
|
`
|
||||||
|
|
||||||
// 3. Scan 的时候也要对应加上
|
// 3. Scan 的时候也要对应加上
|
||||||
err := s.db.QueryRow(query,
|
err := s.db.QueryRow(query,
|
||||||
doc.ID,
|
doc.ID,
|
||||||
doc.Name,
|
doc.Name,
|
||||||
doc.Type,
|
doc.Type,
|
||||||
doc.OwnerID,
|
doc.OwnerID,
|
||||||
doc.YjsState, // $5
|
doc.YjsState, // $5
|
||||||
doc.Is_Public, // $6
|
doc.Is_Public, // $6
|
||||||
doc.CreatedAt,
|
doc.CreatedAt,
|
||||||
doc.UpdatedAt,
|
doc.UpdatedAt,
|
||||||
).Scan(
|
).Scan(
|
||||||
&doc.ID,
|
&doc.ID,
|
||||||
&doc.Name,
|
&doc.Name,
|
||||||
&doc.Type,
|
&doc.Type,
|
||||||
&doc.OwnerID,
|
&doc.OwnerID,
|
||||||
&doc.YjsState, // Scan 回来
|
&doc.YjsState, // Scan 回来
|
||||||
&doc.Is_Public, // Scan 回来
|
&doc.Is_Public, // Scan 回来
|
||||||
&doc.CreatedAt,
|
&doc.CreatedAt,
|
||||||
&doc.UpdatedAt,
|
&doc.UpdatedAt,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create document: %w", err)
|
return nil, fmt.Errorf("failed to create document: %w", err)
|
||||||
}
|
}
|
||||||
return doc, nil
|
return doc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListUserDocuments lists documents owned by or shared with a user
|
// ListUserDocuments lists documents owned by or shared with a user
|
||||||
|
|||||||
@@ -12,77 +12,77 @@ import (
|
|||||||
|
|
||||||
// CreateSession creates a new session
|
// CreateSession creates a new session
|
||||||
func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) {
|
func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) {
|
||||||
// Hash the token before storing
|
// Hash the token before storing
|
||||||
hash := sha256.Sum256([]byte(token))
|
hash := sha256.Sum256([]byte(token))
|
||||||
tokenHash := hex.EncodeToString(hash[:])
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
// 【修改点 1】: 在 SQL 里显式加上 id 字段
|
// 【修改点 1】: 在 SQL 里显式加上 id 字段
|
||||||
// 注意:$1 变成了 id,后面的参数序号全部要顺延 (+1)
|
// 注意:$1 变成了 id,后面的参数序号全部要顺延 (+1)
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address)
|
INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6)
|
VALUES ($1, $2, $3, $4, $5, $6)
|
||||||
RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
|
RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
|
||||||
`
|
`
|
||||||
|
|
||||||
var session models.Session
|
var session models.Session
|
||||||
// 【修改点 2】: 在参数列表的最前面加上 sessionID
|
// 【修改点 2】: 在参数列表的最前面加上 sessionID
|
||||||
// 现在的对应关系:
|
// 现在的对应关系:
|
||||||
// $1 -> sessionID
|
// $1 -> sessionID
|
||||||
// $2 -> userID
|
// $2 -> userID
|
||||||
// $3 -> tokenHash
|
// $3 -> tokenHash
|
||||||
// ...
|
// ...
|
||||||
err := s.db.QueryRowContext(ctx, query,
|
err := s.db.QueryRowContext(ctx, query,
|
||||||
sessionID, // <--- 这里!把它传进去!
|
sessionID, // <--- 这里!把它传进去!
|
||||||
userID,
|
userID,
|
||||||
tokenHash,
|
tokenHash,
|
||||||
expiresAt,
|
expiresAt,
|
||||||
userAgent,
|
userAgent,
|
||||||
ipAddress,
|
ipAddress,
|
||||||
).Scan(
|
).Scan(
|
||||||
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
|
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
|
||||||
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
|
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &session, nil
|
return &session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSessionByToken retrieves session by JWT token
|
// GetSessionByToken retrieves session by JWT token
|
||||||
func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) {
|
func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) {
|
||||||
hash := sha256.Sum256([]byte(token))
|
hash := sha256.Sum256([]byte(token))
|
||||||
tokenHash := hex.EncodeToString(hash[:])
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
|
SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
|
||||||
FROM sessions
|
FROM sessions
|
||||||
WHERE token_hash = $1 AND expires_at > NOW()
|
WHERE token_hash = $1 AND expires_at > NOW()
|
||||||
`
|
`
|
||||||
|
|
||||||
var session models.Session
|
var session models.Session
|
||||||
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
|
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
|
||||||
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
|
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
|
||||||
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
|
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &session, nil
|
return &session, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteSession deletes a session (logout)
|
// DeleteSession deletes a session (logout)
|
||||||
func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error {
|
func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error {
|
||||||
hash := sha256.Sum256([]byte(token))
|
hash := sha256.Sum256([]byte(token))
|
||||||
tokenHash := hex.EncodeToString(hash[:])
|
tokenHash := hex.EncodeToString(hash[:])
|
||||||
|
|
||||||
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
|
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// CleanupExpiredSessions removes expired sessions
|
// CleanupExpiredSessions removes expired sessions
|
||||||
func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error {
|
func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error {
|
||||||
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
|
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,34 +14,34 @@ import (
|
|||||||
// CreateDocumentShare creates a new share or updates existing one
|
// CreateDocumentShare creates a new share or updates existing one
|
||||||
// Returns the share and a boolean indicating if it was newly created (true) or updated (false)
|
// Returns the share and a boolean indicating if it was newly created (true) or updated (false)
|
||||||
func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) {
|
func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, bool, error) {
|
||||||
// First check if share already exists
|
// First check if share already exists
|
||||||
var existingID uuid.UUID
|
var existingID uuid.UUID
|
||||||
checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2`
|
checkQuery := `SELECT id FROM document_shares WHERE document_id = $1 AND user_id = $2`
|
||||||
err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID)
|
err := s.db.QueryRowContext(ctx, checkQuery, documentID, userID).Scan(&existingID)
|
||||||
isNewShare := err != nil // If error (not found), it's a new share
|
isNewShare := err != nil // If error (not found), it's a new share
|
||||||
|
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO document_shares (document_id, user_id, permission, created_by)
|
INSERT INTO document_shares (document_id, user_id, permission, created_by)
|
||||||
VALUES ($1, $2, $3, $4)
|
VALUES ($1, $2, $3, $4)
|
||||||
ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission
|
ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission
|
||||||
RETURNING id, document_id, user_id, permission, created_at, created_by
|
RETURNING id, document_id, user_id, permission, created_at, created_by
|
||||||
`
|
`
|
||||||
|
|
||||||
var share models.DocumentShare
|
var share models.DocumentShare
|
||||||
err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
|
err = s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
|
||||||
&share.ID, &share.DocumentID, &share.UserID, &share.Permission,
|
&share.ID, &share.DocumentID, &share.UserID, &share.Permission,
|
||||||
&share.CreatedAt, &share.CreatedBy,
|
&share.CreatedAt, &share.CreatedBy,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &share, isNewShare, nil
|
return &share, isNewShare, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListDocumentShares lists all shares for a document
|
// ListDocumentShares lists all shares for a document
|
||||||
func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) {
|
func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT
|
SELECT
|
||||||
ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by,
|
ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by,
|
||||||
u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at
|
u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at
|
||||||
@@ -51,38 +51,38 @@ func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.
|
|||||||
ORDER BY ds.created_at DESC
|
ORDER BY ds.created_at DESC
|
||||||
`
|
`
|
||||||
|
|
||||||
rows, err := s.db.QueryContext(ctx, query, documentID)
|
rows, err := s.db.QueryContext(ctx, query, documentID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer rows.Close()
|
||||||
|
|
||||||
var shares []models.DocumentShareWithUser
|
var shares []models.DocumentShareWithUser
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var share models.DocumentShareWithUser
|
var share models.DocumentShareWithUser
|
||||||
err := rows.Scan(
|
err := rows.Scan(
|
||||||
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
|
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
|
||||||
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
|
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
|
||||||
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
|
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
shares = append(shares, share)
|
shares = append(shares, share)
|
||||||
}
|
}
|
||||||
|
|
||||||
return shares, nil
|
return shares, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteDocumentShare deletes a share
|
// DeleteDocumentShare deletes a share
|
||||||
func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error {
|
func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error {
|
||||||
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
|
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// CanViewDocument checks if user can view document (owner OR has any share)
|
// CanViewDocument checks if user can view document (owner OR has any share)
|
||||||
func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT EXISTS(
|
SELECT EXISTS(
|
||||||
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
|
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
|
||||||
UNION
|
UNION
|
||||||
@@ -90,14 +90,14 @@ func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID
|
|||||||
)
|
)
|
||||||
`
|
`
|
||||||
|
|
||||||
var canView bool
|
var canView bool
|
||||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
|
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
|
||||||
return canView, err
|
return canView, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// CanEditDocument checks if user can edit document (owner OR has edit share)
|
// CanEditDocument checks if user can edit document (owner OR has edit share)
|
||||||
func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT EXISTS(
|
SELECT EXISTS(
|
||||||
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
|
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
|
||||||
UNION
|
UNION
|
||||||
@@ -105,21 +105,21 @@ func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID
|
|||||||
)
|
)
|
||||||
`
|
`
|
||||||
|
|
||||||
var canEdit bool
|
var canEdit bool
|
||||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
|
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
|
||||||
return canEdit, err
|
return canEdit, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsDocumentOwner checks if user is the owner
|
// IsDocumentOwner checks if user is the owner
|
||||||
func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
||||||
query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
|
query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
|
||||||
|
|
||||||
var isOwner bool
|
var isOwner bool
|
||||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
|
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
return isOwner, err
|
return isOwner, err
|
||||||
}
|
}
|
||||||
func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) {
|
func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) {
|
||||||
// Generate random 32-byte token
|
// Generate random 32-byte token
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
|
|
||||||
// UpsertUser creates or updates user from OAuth profile
|
// UpsertUser creates or updates user from OAuth profile
|
||||||
func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) {
|
func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) {
|
||||||
query := `
|
query := `
|
||||||
INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at)
|
INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at)
|
||||||
VALUES ($1, $2, $3, $4, $5, NOW())
|
VALUES ($1, $2, $3, $4, $5, NOW())
|
||||||
ON CONFLICT (provider, provider_user_id)
|
ON CONFLICT (provider, provider_user_id)
|
||||||
@@ -24,59 +24,59 @@ func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID
|
|||||||
RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
||||||
`
|
`
|
||||||
|
|
||||||
var user models.User
|
var user models.User
|
||||||
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
|
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
|
||||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
|
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
|
||||||
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserByID retrieves user by ID
|
// GetUserByID retrieves user by ID
|
||||||
func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
|
func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
||||||
FROM users WHERE id = $1
|
FROM users WHERE id = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
var user models.User
|
var user models.User
|
||||||
err := s.db.QueryRowContext(ctx, query, userID).Scan(
|
err := s.db.QueryRowContext(ctx, query, userID).Scan(
|
||||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||||
)
|
)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetUserByEmail retrieves user by email
|
// GetUserByEmail retrieves user by email
|
||||||
func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
|
func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
|
||||||
query := `
|
query := `
|
||||||
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
||||||
FROM users WHERE email = $1
|
FROM users WHERE email = $1
|
||||||
`
|
`
|
||||||
|
|
||||||
var user models.User
|
var user models.User
|
||||||
err := s.db.QueryRowContext(ctx, query, email).Scan(
|
err := s.db.QueryRowContext(ctx, query, email).Scan(
|
||||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||||
)
|
)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user