feat: Enhance real-time collaboration features with user awareness and document sharing

- Added user information (UserID, UserName, UserAvatar) to Client struct for presence tracking.
- Implemented failure handling in the broadcastMessage function to manage send failures and disconnect clients if necessary.
- Introduced document ownership and sharing capabilities:
  - Added OwnerID and Is_Public fields to Document model.
  - Created DocumentShare model for managing document sharing with permissions.
  - Implemented functions for creating, listing, and managing document shares in the Postgres store.
- Added user management functionality:
  - Created User model and associated functions for user management in the Postgres store.
  - Implemented session management with token hashing for security.
- Updated database schema with migrations for users, sessions, and document shares.
- Enhanced frontend Yjs integration with awareness event logging for user connections and disconnections.
This commit is contained in:
M1ngdaXie
2026-01-03 12:59:53 -08:00
parent 37d89b13b9
commit 7f5f32179b
21 changed files with 2064 additions and 232 deletions

View File

@@ -0,0 +1,63 @@
package auth
import (
"errors"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// UserClaims defines the custom claims structure
// Senior Tip: Embed information that helps you avoid DB lookups later.
type UserClaims struct {
Name string `json:"user_name"`
Email string `json:"user_email"`
AvatarURL *string `json:"avatar_url"` // Nullable avatar URL to avoid DB queries
jwt.RegisteredClaims
}
// GenerateJWT creates a stateless JWT token for a user
// Changed: Input is now userID (and optional role), not sessionID
func GenerateJWT(userID uuid.UUID, name string, email string, avatarURL *string, secret string, expiresIn time.Duration) (string, error) {
claims := UserClaims{
Name: name,
Email: email,
AvatarURL: avatarURL,
RegisteredClaims: jwt.RegisteredClaims{
// Standard claim "Subject" is technically where UserID belongs,
// but having a typed UserID field is easier for Go type assertions.
Subject: userID.String(),
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiresIn)),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: "realtime-collab", // Your app name
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString([]byte(secret))
}
// ValidateJWT parses the token and extracts the UserClaims
// Changed: Returns *UserClaims so you can access UserID and Role directly
func ValidateJWT(tokenString, secret string) (*UserClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
// Security Check: Always validate the signing algorithm
// to prevent "None" algorithm attacks.
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, errors.New("invalid signing method")
}
return []byte(secret), nil
})
if err != nil {
return nil, err
}
// Type assertion to get our custom struct back
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
return claims, nil
}
return nil, errors.New("invalid token claims")
}

View File

@@ -0,0 +1,193 @@
package auth
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
type contextKey string
const UserContextKey contextKey = "user"
const ContextUserIDKey = "user_id"
// AuthMiddleware provides auth middleware
type AuthMiddleware struct {
store store.Store
jwtSecret string
}
// NewAuthMiddleware creates a new auth middleware
func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware {
return &AuthMiddleware{
store: store,
jwtSecret: jwtSecret,
}
}
// RequireAuth middleware requires valid authentication
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) {
fmt.Println("🔒 RequireAuth: Starting authentication check")
user, claims, err := m.getUserFromToken(c)
fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
if claims != nil {
fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
}
if err != nil || user == nil {
fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user)
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
// Note: Name and Email might be empty for old JWT tokens
if claims.Name == "" || claims.Email == "" {
fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
}
fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
c.Set(ContextUserIDKey, user)
c.Set("user_email", claims.Email)
c.Set("user_name", claims.Name)
if claims.AvatarURL != nil {
c.Set("avatar_url", *claims.AvatarURL)
}
c.Next()
}
}
// OptionalAuth middleware sets user if authenticated, but doesn't require it
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
return func(c *gin.Context) {
user, claims, _ := m.getUserFromToken(c)
if user != nil {
c.Set(string(UserContextKey), user)
c.Set(ContextUserIDKey, user)
if claims != nil {
c.Set("user_email", claims.Email)
c.Set("user_name", claims.Name)
if claims.AvatarURL != nil {
c.Set("avatar_url", *claims.AvatarURL)
}
}
}
c.Next()
}
}
// getUserFromToken parses the JWT and returns the UserID and the full Claims (for name/email)
// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error)
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
authHeader := c.GetHeader("Authorization")
fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)
if authHeader == "" {
fmt.Println("⚠️ getUserFromToken: No Authorization header")
return nil, nil, nil
}
parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" {
fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
return nil, nil, nil
}
tokenString := parts[1]
fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
// 必须要验证签名算法是 HMAC (HS256)
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(m.jwtSecret), nil
})
if err != nil {
fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
return nil, nil, err
}
// 2. 验证 Token 有效性并提取 Claims
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
// 3. 把 String 类型的 Subject 转回 UUID
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
userID, err := uuid.Parse(claims.Subject)
if err != nil {
fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
return nil, nil, fmt.Errorf("invalid user ID in token")
}
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
// 这一步完全没有查数据库,速度极快
fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
return &userID, claims, nil
}
fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
return nil, nil, fmt.Errorf("invalid token claims")
}
// GetUserFromContext extracts user ID from context
func GetUserFromContext(c *gin.Context) *uuid.UUID {
// 修正点:使用和存入时完全一样的 Key
val, exists := c.Get(ContextUserIDKey)
fmt.Println("within getFromContext the id is ... ")
fmt.Println(val);
if !exists {
return nil
}
// 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型)
uid, ok := val.(*uuid.UUID)
if !ok {
return nil
}
return uid
}
// ValidateToken validates a JWT token and returns user ID, name, and avatar URL from JWT claims
func (m *AuthMiddleware) ValidateToken(tokenString string) (*uuid.UUID, string, string, error) {
// Parse and validate JWT
claims, err := ValidateJWT(tokenString, m.jwtSecret)
if err != nil {
return nil, "", "", fmt.Errorf("invalid token: %w", err)
}
// Parse user ID from claims
userID, err := uuid.Parse(claims.Subject)
if err != nil {
return nil, "", "", fmt.Errorf("invalid user ID in token: %w", err)
}
// Get session from database by token (for revocation capability)
session, err := m.store.GetSessionByToken(context.Background(), tokenString)
if err != nil {
return nil, "", "", fmt.Errorf("session not found: %w", err)
}
// Verify session UserID matches JWT Subject
if session.UserID != userID {
return nil, "", "", fmt.Errorf("session ID mismatch")
}
// Extract avatar URL from claims (handle nil gracefully)
avatarURL := ""
if claims.AvatarURL != nil {
avatarURL = *claims.AvatarURL
}
// Return user data from JWT claims - no DB query needed!
return &userID, claims.Name, avatarURL, nil
}

View File

@@ -0,0 +1,32 @@
package auth
import (
"golang.org/x/oauth2"
"golang.org/x/oauth2/github"
"golang.org/x/oauth2/google"
)
// GetGoogleOAuthConfig returns Google OAuth2 config
func GetGoogleOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
return &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
Scopes: []string{
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
}
}
// GetGitHubOAuthConfig returns GitHub OAuth2 config
func GetGitHubOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
return &oauth2.Config{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURL: redirectURL,
Scopes: []string{"user:email"},
Endpoint: github.Endpoint,
}
}

View File

@@ -0,0 +1,302 @@
package handlers
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"os"
"time"
"github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"golang.org/x/oauth2"
)
type AuthHandler struct {
store store.Store
googleConfig *oauth2.Config
githubConfig *oauth2.Config
jwtSecret string
frontendURL string
}
func NewAuthHandler(store store.Store, jwtSecret, frontendURL string) *AuthHandler {
googleConfig := auth.GetGoogleOAuthConfig(
os.Getenv("GOOGLE_CLIENT_ID"),
os.Getenv("GOOGLE_CLIENT_SECRET"),
os.Getenv("GOOGLE_REDIRECT_URL"),
)
githubConfig := auth.GetGitHubOAuthConfig(
os.Getenv("GITHUB_CLIENT_ID"),
os.Getenv("GITHUB_CLIENT_SECRET"),
os.Getenv("GITHUB_REDIRECT_URL"),
)
return &AuthHandler{
store: store,
googleConfig: googleConfig,
githubConfig: githubConfig,
jwtSecret: jwtSecret,
frontendURL: frontendURL,
}
}
// GoogleLogin redirects to Google OAuth
func (h *AuthHandler) GoogleLogin(c *gin.Context) {
// Generate random state and set cookie
oauthState := generateStateOauthCookie(c.Writer)
url := h.googleConfig.AuthCodeURL(oauthState, oauth2.AccessTypeOffline)
c.Redirect(http.StatusTemporaryRedirect, url)
}
// GoogleCallback handles Google OAuth callback
func (h *AuthHandler) GoogleCallback(c *gin.Context) {
oauthState, err := c.Cookie("oauthstate")
if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return
}
log.Println("Google callback state:", c.Query("state"))
// Exchange code for token
token, err := h.googleConfig.Exchange(context.Background(), c.Query("code"))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange token"})
return
}
// Get user info from Google
client := h.googleConfig.Client(context.Background(), token)
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
return
}
log.Println("Google user info response status:", resp.Status)
log.Println("Google user info response headers:", resp.Header)
defer resp.Body.Close()
data, _ := io.ReadAll(resp.Body)
var userInfo struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Picture string `json:"picture"`
}
json.Unmarshal(data, &userInfo)
log.Println("Google user info:", userInfo)
// Upsert user in database
user, err := h.store.UpsertUser(
c.Request.Context(),
"google",
userInfo.ID,
userInfo.Email,
userInfo.Name,
&userInfo.Picture,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return
}
// Create session and JWT
jwt, err := h.createSessionAndJWT(c, user)
if err != nil {
fmt.Printf("❌ DATABASE ERROR: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("CreateSession Error: %v", err),
})
return
}
// Redirect to frontend with token
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s", h.frontendURL, jwt)
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
}
// GithubLogin redirects to GitHub OAuth
func (h *AuthHandler) GithubLogin(c *gin.Context) {
url := h.githubConfig.AuthCodeURL("state", oauth2.AccessTypeOffline)
c.Redirect(http.StatusTemporaryRedirect, url)
}
// GithubCallback handles GitHub OAuth callback
func (h *AuthHandler) GithubCallback(c *gin.Context) {
code := c.Query("code")
if code == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "No code provided"})
return
}
// Exchange code for token
token, err := h.githubConfig.Exchange(context.Background(), code)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange token"})
return
}
// Get user info from GitHub
client := h.githubConfig.Client(context.Background(), token)
// Get user profile
resp, err := client.Get("https://api.github.com/user")
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
return
}
defer resp.Body.Close()
data, _ := io.ReadAll(resp.Body)
var userInfo struct {
ID int `json:"id"`
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatar_url"`
}
json.Unmarshal(data, &userInfo)
// If email is not public, fetch it separately
if userInfo.Email == "" {
emailResp, _ := client.Get("https://api.github.com/user/emails")
if emailResp != nil {
defer emailResp.Body.Close()
emailData, _ := io.ReadAll(emailResp.Body)
var emails []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
json.Unmarshal(emailData, &emails)
for _, e := range emails {
if e.Primary {
userInfo.Email = e.Email
break
}
}
}
}
// Use login as name if name is empty
if userInfo.Name == "" {
userInfo.Name = userInfo.Login
}
// Upsert user in database
user, err := h.store.UpsertUser(
c.Request.Context(),
"github",
fmt.Sprintf("%d", userInfo.ID),
userInfo.Email,
userInfo.Name,
&userInfo.AvatarURL,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
return
}
// Create session and JWT
jwt, err := h.createSessionAndJWT(c, user)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"})
return
}
// Redirect to frontend with token
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s", h.frontendURL, jwt)
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
}
// Me returns current user info
func (h *AuthHandler) Me(c *gin.Context) {
userID := auth.GetUserFromContext(c)
if userID == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
user, err := h.store.GetUserByID(c.Request.Context(), *userID)
if err != nil || user == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
return
}
c.JSON(http.StatusOK, models.UserResponse{User: user})
}
// Logout invalidates the session
func (h *AuthHandler) Logout(c *gin.Context) {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
c.JSON(http.StatusOK, gin.H{"message": "Already logged out"})
return
}
// Extract token
token := ""
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
token = authHeader[7:]
}
if token != "" {
h.store.DeleteSession(c.Request.Context(), token)
}
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
}
// Helper: create session and JWT
func (h *AuthHandler) createSessionAndJWT(c *gin.Context, user *models.User) (string, error) {
expiresAt := time.Now().Add(7 * 24 * time.Hour) // 7 days
// Generate JWT first (we need it for session) - now includes avatar URL
jwt, err := auth.GenerateJWT(user.ID, user.Name, user.Email, user.AvatarURL, h.jwtSecret, 7*24*time.Hour)
if err != nil {
return "", err
}
// Create session in database
sessionID := uuid.New()
userAgent := c.GetHeader("User-Agent")
ipAddress := c.ClientIP()
_, err = h.store.CreateSession(
c.Request.Context(),
user.ID,
sessionID,
jwt,
expiresAt,
&userAgent,
&ipAddress,
)
if err != nil {
return "", err
}
return jwt, nil
}
func generateStateOauthCookie(w http.ResponseWriter) string {
b := make([]byte, 16)
rand.Read(b)
state := base64.URLEncoding.EncodeToString(b)
cookie := http.Cookie{
Name: "oauthstate",
Value: state,
Expires: time.Now().Add(10 * time.Minute),
HttpOnly: true, // Prevents JavaScript access (XSS protection)
Secure: false, // Must be false for http://localhost (set true in production)
SameSite: http.SameSiteLaxMode, // ✅ Allows same-site OAuth redirects
Path: "/", // ✅ Ensures cookie is sent to all backend paths
}
http.SetCookie(w, &cookie)
return state
}

View File

@@ -1,8 +1,10 @@
package handlers
import (
"fmt"
"net/http"
"github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/gin-gonic/gin"
@@ -10,135 +12,199 @@ import (
)
type DocumentHandler struct {
store *store.Store
store *store.PostgresStore
}
func NewDocumentHandler(s *store.Store) *DocumentHandler {
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
return &DocumentHandler{store: s}
}
// CreateDocument creates a new document
func (h *DocumentHandler) CreateDocument(c *gin.Context) {
var req models.CreateDocumentRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Validate document type
if req.Type != models.DocumentTypeEditor && req.Type != models.DocumentTypeKanban {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document type"})
return
}
// CreateDocument creates a new document (requires auth)
func (h *DocumentHandler) CreateDocument(c *gin.Context) {
fmt.Println("getting userId right now.... ")
userID := auth.GetUserFromContext(c)
fmt.Println(userID)
if userID == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
doc, err := h.store.CreateDocument(req.Name, req.Type)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var req models.CreateDocumentRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusCreated, doc)
}
// Create document with owner_id
doc, err := h.store.CreateDocumentWithOwner(req.Name, req.Type, userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to create document: %v", err)})
return
}
c.JSON(http.StatusCreated, doc)
}
// ListDocuments returns all documents
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
documents, err := h.store.ListDocuments()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
userID := auth.GetUserFromContext(c)
if documents == nil {
documents = []models.Document{}
}
var docs []models.Document
var err error
if userID != nil {
// Authenticated: show owned + shared documents
docs, err = h.store.ListUserDocuments(c.Request.Context(), *userID)
} else {
c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("we dont know you: %v", err)})
}
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list documents"})
return
}
c.JSON(http.StatusOK, models.DocumentListResponse{
Documents: docs,
Total: len(docs),
})
}
c.JSON(http.StatusOK, models.DocumentListResponse{
Documents: documents,
Total: len(documents),
})
}
// GetDocument returns a single document
func (h *DocumentHandler) GetDocument(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"})
return
}
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
doc, err := h.store.GetDocument(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "document not found"})
return
}
userID := auth.GetUserFromContext(c)
c.JSON(http.StatusOK, doc)
}
// Check permission if authenticated
if userID != nil {
canView, err := h.store.CanViewDocument(c.Request.Context(), id, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"})
return
}
if !canView {
c.JSON(http.StatusForbidden, gin.H{"error": "Access denied"})
return
}
}else{
c.JSON("this file is not public")
return
}
doc, err := h.store.GetDocument(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Document not found"})
return
}
c.JSON(http.StatusOK, doc)
}
// GetDocumentState returns the Yjs state for a document
func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"})
return
}
// GetDocumentState retrieves document state (requires view permission)
func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
doc, err := h.store.GetDocument(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "document not found"})
return
}
userID := auth.GetUserFromContext(c)
// Return binary state
if doc.YjsState == nil {
c.Data(http.StatusOK, "application/octet-stream", []byte{})
return
}
// Check permission if authenticated
if userID != nil {
canView, err := h.store.CanViewDocument(c.Request.Context(), id, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"})
return
}
if !canView {
c.JSON(http.StatusForbidden, gin.H{"error": "Access denied"})
return
}
}
c.Data(http.StatusOK, "application/octet-stream", doc.YjsState)
}
doc, err := h.store.GetDocument(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "Document not found"})
return
}
// UpdateDocumentState updates the Yjs state for a document
func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"})
return
}
c.Data(http.StatusOK, "application/octet-stream", doc.YjsState)
}
// Read binary body
state, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
// UpdateDocumentState updates document state (requires edit permission)
func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
err = h.store.UpdateDocumentState(id, state)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
userID := auth.GetUserFromContext(c)
if userID == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "state updated successfully"})
}
// Check edit permission
canEdit, err := h.store.CanEditDocument(c.Request.Context(), id, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"})
return
}
if !canEdit {
c.JSON(http.StatusForbidden, gin.H{"error": "Edit access denied"})
return
}
// DeleteDocument deletes a document
func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
idStr := c.Param("id")
id, err := uuid.Parse(idStr)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"})
return
}
var req models.UpdateStateRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
err = h.store.DeleteDocument(id)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": "document not found"})
return
}
if err := h.store.UpdateDocumentState(id, req.State); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update state"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "document deleted successfully"})
}
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
}
// DeleteDocument deletes a document (owner only)
func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
userID := auth.GetUserFromContext(c)
if userID == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// Check ownership
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), id, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
return
}
if !isOwner {
c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can delete documents"})
return
}
if err := h.store.DeleteDocument(id); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete document"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"})
}

View File

@@ -0,0 +1,286 @@
package handlers
import (
"fmt"
"net/http"
"os" // Add this
"github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
type ShareHandler struct {
store store.Store
}
func NewShareHandler(store store.Store) *ShareHandler {
return &ShareHandler{store: store}
}
// CreateShare creates a new document share
func (h *ShareHandler) CreateShare(c *gin.Context) {
userID := auth.GetUserFromContext(c)
if userID == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
documentID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
// Check if user is owner
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
return
}
if !isOwner {
c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can share documents"})
return
}
var req models.CreateShareRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Get user by email
targetUser, err := h.store.GetUserByEmail(c.Request.Context(), req.UserEmail)
if err != nil || targetUser == nil {
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
return
}
// Create share
share, err := h.store.CreateDocumentShare(
c.Request.Context(),
documentID,
targetUser.ID,
req.Permission,
userID,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create share"})
return
}
c.JSON(http.StatusCreated, share)
}
// ListShares lists all shares for a document
func (h *ShareHandler) ListShares(c *gin.Context) {
userID := auth.GetUserFromContext(c)
if userID == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
documentID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
// Check if user is owner
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
return
}
if !isOwner {
c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can view shares"})
return
}
shares, err := h.store.ListDocumentShares(c.Request.Context(), documentID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list shares"})
return
}
c.JSON(http.StatusOK, models.ShareListResponse{Shares: shares})
}
// DeleteShare removes a share
func (h *ShareHandler) DeleteShare(c *gin.Context) {
userID := auth.GetUserFromContext(c)
if userID == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
documentID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
targetUserID, err := uuid.Parse(c.Param("userId"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
return
}
// Check if user is owner
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
return
}
if !isOwner {
c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can delete shares"})
return
}
err = h.store.DeleteDocumentShare(c.Request.Context(), documentID, targetUserID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete share"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Share deleted successfully"})
}
// CreateShareLink generates a public share link
func (h *ShareHandler) CreateShareLink(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
userID := auth.GetUserFromContext(c)
if userID == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// Check if user is owner
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
return
}
if !isOwner {
c.JSON(http.StatusForbidden, gin.H{"error": "Only document owner can create share links"})
return
}
// Parse request body
var req struct {
Permission string `json:"permission" binding:"required,oneof=view edit"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Permission must be 'view' or 'edit'"})
return
}
// Generate share token
token, err := h.store.GenerateShareToken(c.Request.Context(), documentID, req.Permission)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate share link"})
return
}
// Get frontend URL from env
frontendURL := os.Getenv("FRONTEND_URL")
if frontendURL == "" {
frontendURL = "http://localhost:5173"
}
shareURL := fmt.Sprintf("%s/editor/%s?share=%s", frontendURL, documentID.String(), token)
c.JSON(http.StatusOK, gin.H{
"url": shareURL,
"token": token,
"permission": req.Permission,
})
}
// GetShareLink retrieves the current public share link
func (h *ShareHandler) GetShareLink(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
userID := auth.GetUserFromContext(c)
if userID == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// Check if user is owner
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
return
}
if !isOwner {
c.JSON(http.StatusForbidden, gin.H{"error": "Only document owner can view share links"})
return
}
token, exists, err := h.store.GetShareToken(c.Request.Context(), documentID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get share link"})
return
}
if !exists {
c.JSON(http.StatusNotFound, gin.H{"error": "No public share link exists"})
return
}
frontendURL := os.Getenv("FRONTEND_URL")
if frontendURL == "" {
frontendURL = "http://localhost:5173"
}
shareURL := fmt.Sprintf("%s/editor/%s?share=%s", frontendURL, documentID.String(), token)
c.JSON(http.StatusOK, gin.H{
"url": shareURL,
"token": token,
})
}
// RevokeShareLink removes the public share link
func (h *ShareHandler) RevokeShareLink(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
userID := auth.GetUserFromContext(c)
if userID == nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
return
}
// Check if user is owner
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
return
}
if !isOwner {
c.JSON(http.StatusForbidden, gin.H{"error": "Only document owner can revoke share links"})
return
}
err = h.store.RevokeShareToken(c.Request.Context(), documentID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to revoke share link"})
return
}
c.JSON(http.StatusOK, gin.H{"message": "Share link revoked"})
}

View File

@@ -3,57 +3,147 @@ package handlers
import (
"log"
"net/http"
"os"
"github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/hub"
"github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// Allow all origins for development
// TODO: Restrict in production
return true
},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
// Check origin against allowed origins from environment
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
if allowedOrigins == "" {
// Default for development
origin := r.Header.Get("Origin")
return origin == "http://localhost:5173" || origin == "http://localhost:3000"
}
// Production: validate against ALLOWED_ORIGINS
// TODO: Parse and validate origin
return true
},
}
type WebSocketHandler struct {
hub *hub.Hub
}
type WebSocketHandler struct {
hub *hub.Hub
store store.Store
}
func NewWebSocketHandler(h *hub.Hub) *WebSocketHandler {
return &WebSocketHandler{hub: h}
}
func NewWebSocketHandler(h *hub.Hub, s store.Store) *WebSocketHandler {
return &WebSocketHandler{
hub: h,
store: s,
}
}
func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context){
func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
roomID := c.Param("roomId")
if(roomID == ""){
if roomID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "roomId is required"})
return
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
// Parse document ID
documentID, err := uuid.Parse(roomID)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upgrade to WebSocket"})
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
return
}
// Create a new client
clientID := uuid.New().String()
client := hub.NewClient(clientID, conn, wsh.hub, roomID)
// Register client with hub
wsh.hub.Register <- client
// Try to authenticate via JWT token or share token
var userID *uuid.UUID
var userName string
var userAvatar *string
authenticated := false
// Start read and write pumps in separate goroutines
go client.WritePump()
go client.ReadPump()
// Check for JWT token in query parameter
jwtToken := c.Query("token")
if jwtToken != "" {
// Validate JWT and get user data from token claims (no DB query!)
jwtSecret := os.Getenv("JWT_SECRET")
if jwtSecret == "" {
log.Println("JWT_SECRET not configured")
c.JSON(http.StatusInternalServerError, gin.H{"error": "Server configuration error"})
return
}
log.Printf("WebSocket connection established for client %s in room %s", clientID, roomID)
}
authMiddleware := auth.NewAuthMiddleware(wsh.store, jwtSecret)
uid, name, avatar, err := authMiddleware.ValidateToken(jwtToken)
if err == nil && uid != nil {
// User data comes directly from JWT claims - no DB query needed!
userID = uid
userName = name
if avatar != "" {
userAvatar = &avatar
}
authenticated = true
}
}
// If not authenticated via JWT, check for share token
if !authenticated {
shareToken := c.Query("share")
if shareToken != "" {
// Validate share token
valid, err := wsh.store.ValidateShareToken(c.Request.Context(), documentID, shareToken)
if err != nil {
log.Printf("Error validating share token: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to validate share token"})
return
}
if !valid {
c.JSON(http.StatusForbidden, gin.H{"error": "Invalid or expired share token"})
return
}
// Share token is valid, allow connection with anonymous user
userName = "Anonymous"
authenticated = true
}
}
// If still not authenticated, reject connection
if !authenticated {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required. Provide 'token' or 'share' query parameter"})
return
}
// If authenticated with JWT, check document permissions
if userID != nil {
canView, err := wsh.store.CanViewDocument(c.Request.Context(), documentID, *userID)
if err != nil {
log.Printf("Error checking permissions: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"})
return
}
if !canView {
c.JSON(http.StatusForbidden, gin.H{"error": "You don't have permission to access this document"})
return
}
}
// Upgrade connection
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
log.Printf("Failed to upgrade connection: %v", err)
return
}
// Create client with user information
clientID := uuid.New().String()
client := hub.NewClient(clientID, userID, userName, userAvatar, conn, wsh.hub, roomID)
// Register client
wsh.hub.Register <- client
// Start goroutines
go client.WritePump()
go client.ReadPump()
log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID)
}

View File

@@ -3,7 +3,9 @@ package hub
import (
"log"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
)
@@ -14,11 +16,20 @@ type Message struct {
}
type Client struct {
ID string
Conn *websocket.Conn
send chan []byte
hub *Hub
roomID string
ID string
UserID *uuid.UUID // Authenticated user ID (nil for public share access)
UserName string // User's display name for presence
UserAvatar *string // User's avatar URL for presence
Conn *websocket.Conn
send chan []byte
sendMu sync.Mutex
sendClosed bool
hub *Hub
roomID string
mutex sync.Mutex
unregisterOnce sync.Once
failureCount int
failureMu sync.Mutex
}
type Room struct {
ID string
@@ -74,54 +85,99 @@ func (h *Hub) registerClient(client *Client) {
log.Printf("Client %s joined room %s (total clients: %d)", client.ID, client.roomID, len(room.clients))
}
func (h *Hub) unregisterClient(client *Client) {
h.mu.Lock()
defer h.mu.Unlock()
h.mu.Lock()
defer h.mu.Unlock()
room, exists := h.rooms[client.roomID]
if !exists {
log.Printf("Room %s does not exist for client %s", client.roomID, client.ID)
return
}
room.mu.Lock()
if _, ok := room.clients[client]; ok {
delete(room.clients, client)
close(client.send)
log.Printf("Client %s disconnected from room %s", client.ID, client.roomID)
}
room, exists := h.rooms[client.roomID]
if !exists {
log.Printf("Room %s does not exist for client %s", client.roomID, client.ID)
return
}
room.mu.Unlock()
log.Printf("Client %s left room %s (total clients: %d)", client.ID, client.roomID, len(room.clients))
room.mu.Lock()
defer room.mu.Unlock()
if len(room.clients) == 0 {
delete(h.rooms, client.roomID)
log.Printf("Deleted empty room with ID: %s", client.roomID)
}
if _, ok := room.clients[client]; ok {
delete(room.clients, client)
// Safely close send channel exactly once
client.sendMu.Lock()
if !client.sendClosed {
close(client.send)
client.sendClosed = true
}
client.sendMu.Unlock()
log.Printf("Client %s disconnected from room %s (total clients: %d)",
client.ID, client.roomID, len(room.clients))
}
if len(room.clients) == 0 {
delete(h.rooms, client.roomID)
log.Printf("Deleted empty room with ID: %s", client.roomID)
}
}
const (
writeWait = 10 * time.Second
pongWait = 60 * time.Second
pingPeriod = (pongWait * 9) / 10 // 54 seconds
maxSendFailures = 5
)
func (h *Hub) broadcastMessage(message *Message) {
h.mu.RLock()
room, exists := h.rooms[message.RoomID]
h.mu.RUnlock()
if !exists {
log.Printf("Room %s does not exist for broadcasting", message.RoomID)
return
}
h.mu.RLock()
room, exists := h.rooms[message.RoomID]
h.mu.RUnlock()
if !exists {
log.Printf("Room %s does not exist for broadcasting", message.RoomID)
return
}
room.mu.RLock()
defer room.mu.RUnlock()
for client := range room.clients {
if client != message.sender {
select {
case client.send <- message.Data:
default:
log.Printf("Failed to send to client %s (channel full)", client.ID)
}
}
}
room.mu.RLock()
defer room.mu.RUnlock()
for client := range room.clients {
if client != message.sender {
select {
case client.send <- message.Data:
// Success - reset failure count
client.failureMu.Lock()
client.failureCount = 0
client.failureMu.Unlock()
default:
// Failed - increment failure count
client.failureMu.Lock()
client.failureCount++
currentFailures := client.failureCount
client.failureMu.Unlock()
log.Printf("Failed to send to client %s (channel full, failures: %d/%d)",
client.ID, currentFailures, maxSendFailures)
// Disconnect if threshold exceeded
if currentFailures >= maxSendFailures {
log.Printf("Client %s exceeded max send failures, disconnecting", client.ID)
go func(c *Client) {
c.unregister()
c.Conn.Close()
}(client)
}
}
}
}
}
func (c *Client) ReadPump() {
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
c.Conn.SetPongHandler(func(string) error {
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
defer func() {
c.hub.Unregister <- c
c.unregister()
c.Conn.Close()
}()
for {
@@ -141,24 +197,54 @@ func (c *Client) ReadPump() {
}
func (c *Client) WritePump() {
defer func() {
c.Conn.Close()
}()
for message := range c.send {
err := c.Conn.WriteMessage(websocket.BinaryMessage, message)
if err != nil {
log.Printf("Error writing message to client %s: %v", c.ID, err)
break
}
}
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.unregister() // NEW: Now WritePump also unregisters
c.Conn.Close()
}()
for {
select {
case message, ok := <-c.send:
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
// Hub closed the channel
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
err := c.Conn.WriteMessage(websocket.BinaryMessage, message)
if err != nil {
log.Printf("Error writing message to client %s: %v", c.ID, err)
return
}
case <-ticker.C:
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
log.Printf("Ping failed for client %s: %v", c.ID, err)
return
}
}
}
}
func NewClient(id string, conn *websocket.Conn, hub *Hub, roomID string) *Client {
func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string, conn *websocket.Conn, hub *Hub, roomID string) *Client {
return &Client{
ID: id,
Conn: conn,
send: make(chan []byte, 256),
hub: hub,
roomID: roomID,
ID: id,
UserID: userID,
UserName: userName,
UserAvatar: userAvatar,
Conn: conn,
send: make(chan []byte, 256),
hub: hub,
roomID: roomID,
}
}
func (c *Client) unregister() {
c.unregisterOnce.Do(func() {
c.hub.Unregister <- c
})
}

View File

@@ -14,14 +14,17 @@ const (
)
type Document struct {
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Type DocumentType `json:"type"`
YjsState []byte `json:"-"` // Don't expose binary data in JSON
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
ID uuid.UUID `json:"id"`
Name string `json:"name"`
Type DocumentType `json:"type"`
YjsState []byte `json:"-"`
OwnerID *uuid.UUID `json:"owner_id"` // NEW
Is_Public bool `json:"is_public"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
type CreateDocumentRequest struct {
Name string `json:"name" binding:"required"`
Type DocumentType `json:"type" binding:"required"`

View File

@@ -0,0 +1,30 @@
package models
import (
"time"
"github.com/google/uuid"
)
type DocumentShare struct {
ID uuid.UUID `json:"id"`
DocumentID uuid.UUID `json:"document_id"`
UserID uuid.UUID `json:"user_id"`
Permission string `json:"permission"` // "view" or "edit"
CreatedAt time.Time `json:"created_at"`
CreatedBy *uuid.UUID `json:"created_by"`
}
type CreateShareRequest struct {
UserEmail string `json:"user_email" binding:"required"`
Permission string `json:"permission" binding:"required,oneof=view edit"`
}
type ShareListResponse struct {
Shares []DocumentShareWithUser `json:"shares"`
}
type DocumentShareWithUser struct {
DocumentShare
User User `json:"user"`
}

View File

@@ -0,0 +1,48 @@
package models
import (
"time"
"github.com/google/uuid"
)
type User struct {
ID uuid.UUID `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
AvatarURL *string `json:"avatar_url"`
Provider string `json:"provider"`
ProviderUserID string `json:"-"` // Don't expose
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
LastLoginAt *time.Time `json:"last_login_at"`
}
type Session struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
TokenHash string `json:"-"` // SHA-256 hash of JWT
ExpiresAt time.Time `json:"expires_at"`
CreatedAt time.Time `json:"created_at"`
UserAgent *string `json:"user_agent"`
IPAddress *string `json:"ip_address"`
}
type OAuthToken struct {
ID uuid.UUID `json:"id"`
UserID uuid.UUID `json:"user_id"`
Provider string `json:"provider"`
AccessToken string `json:"-"` // Don't expose
RefreshToken *string `json:"-"`
TokenType string `json:"token_type"`
ExpiresAt time.Time `json:"expires_at"`
Scope *string `json:"scope"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Response for /auth/me endpoint
type UserResponse struct {
User *User `json:"user"`
Token string `json:"token,omitempty"`
}

View File

@@ -1,6 +1,7 @@
package store
import (
"context"
"database/sql"
"fmt"
"time"
@@ -10,11 +11,49 @@ import (
_ "github.com/lib/pq" // PostgreSQL driver
)
type Store struct{
db *sql.DB
// Store interface defines all database operations
type Store interface {
// Document operations
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
GetDocument(id uuid.UUID) (*models.Document, error)
ListDocuments() ([]models.Document, error)
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
UpdateDocumentState(id uuid.UUID, state []byte) error
DeleteDocument(id uuid.UUID) error
// User operations
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
// Session operations
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
DeleteSession(ctx context.Context, token string) error
CleanupExpiredSessions(ctx context.Context) error
// Share operations
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, error)
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error)
ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error)
RevokeShareToken(ctx context.Context, documentID uuid.UUID) error
GetShareToken(ctx context.Context, documentID uuid.UUID) (string, bool, error)
Close() error
}
func NewStore(databaseUrl string) (*Store, error) {
type PostgresStore struct {
db *sql.DB
}
func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
db, error := sql.Open("postgres", databaseUrl)
if error != nil {
return nil, error
@@ -25,14 +64,14 @@ func NewStore(databaseUrl string) (*Store, error) {
db.SetMaxOpenConns(25)
db.SetMaxIdleConns(5)
db.SetConnMaxLifetime(5 * time.Minute)
return &Store{db: db}, nil
return &PostgresStore{db: db}, nil
}
func (s *Store) Close() error {
func (s *PostgresStore) Close() error {
return s.db.Close()
}
func (s *Store) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
doc := &models.Document{
ID: uuid.New(),
Name: name,
@@ -62,7 +101,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model
}
// GetDocument retrieves a document by ID
func (s *Store) GetDocument(id uuid.UUID) (*models.Document, error) {
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
doc := &models.Document{}
query := `
@@ -92,7 +131,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model
// ListDocuments retrieves all documents
func (s *Store) ListDocuments() ([]models.Document, error) {
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
query := `
SELECT id, name, type, created_at, updated_at
FROM documents
@@ -118,7 +157,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model
return documents, nil
}
func (s *Store) UpdateDocumentState(id uuid.UUID, state []byte) error {
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
query := `
UPDATE documents
SET yjs_state = $1, updated_at = $2
@@ -142,7 +181,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model
return nil
}
func (s *Store) DeleteDocument(id uuid.UUID) error {
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
query := `DELETE FROM documents WHERE id = $1`
result, err := s.db.Exec(query, id)
@@ -162,3 +201,88 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model
return nil
}
// CreateDocumentWithOwner creates a new document with owner
func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) {
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
if docType == "" {
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
}
// Validate that docType is one of the allowed values
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
}
doc := &models.Document{
ID: uuid.New(),
Name: name,
Type: docType,
YjsState: []byte{}, // 这里初始化了空字节
OwnerID: ownerID,
Is_Public: false, // 显式设置默认值
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
// 2. 补全了 yjs_state 和 is_public
query := `
INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at
`
// 3. Scan 的时候也要对应加上
err := s.db.QueryRow(query,
doc.ID,
doc.Name,
doc.Type,
doc.OwnerID,
doc.YjsState, // $5
doc.Is_Public, // $6
doc.CreatedAt,
doc.UpdatedAt,
).Scan(
&doc.ID,
&doc.Name,
&doc.Type,
&doc.OwnerID,
&doc.YjsState, // Scan 回来
&doc.Is_Public, // Scan 回来
&doc.CreatedAt,
&doc.UpdatedAt,
)
if err != nil {
return nil, fmt.Errorf("failed to create document: %w", err)
}
return doc, nil
}
// ListUserDocuments lists documents owned by or shared with a user
func (s *PostgresStore) ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) {
query := `
SELECT DISTINCT d.id, d.name, d.type, d.owner_id, d.created_at, d.updated_at
FROM documents d
LEFT JOIN document_shares ds ON d.id = ds.document_id
WHERE d.owner_id = $1 OR ds.user_id = $1
ORDER BY d.created_at DESC
`
rows, err := s.db.QueryContext(ctx, query, userID)
if err != nil {
return nil, fmt.Errorf("failed to list user documents: %w", err)
}
defer rows.Close()
var documents []models.Document
for rows.Next() {
var doc models.Document
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.OwnerID, &doc.CreatedAt, &doc.UpdatedAt)
if err != nil {
return nil, fmt.Errorf("failed to scan document: %w", err)
}
documents = append(documents, doc)
}
return documents, nil
}

View File

@@ -0,0 +1,88 @@
package store
import (
"context"
"crypto/sha256"
"encoding/hex"
"time"
"github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/google/uuid"
)
// CreateSession creates a new session
func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) {
// Hash the token before storing
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
// 【修改点 1】: 在 SQL 里显式加上 id 字段
// 注意:$1 变成了 id后面的参数序号全部要顺延 (+1)
query := `
INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address)
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
`
var session models.Session
// 【修改点 2】: 在参数列表的最前面加上 sessionID
// 现在的对应关系:
// $1 -> sessionID
// $2 -> userID
// $3 -> tokenHash
// ...
err := s.db.QueryRowContext(ctx, query,
sessionID, // <--- 这里!把它传进去!
userID,
tokenHash,
expiresAt,
userAgent,
ipAddress,
).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
return &session, nil
}
// GetSessionByToken retrieves session by JWT token
func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
query := `
SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
FROM sessions
WHERE token_hash = $1 AND expires_at > NOW()
`
var session models.Session
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
)
if err != nil {
return nil, err
}
return &session, nil
}
// DeleteSession deletes a session (logout)
func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error {
hash := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(hash[:])
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
return err
}
// CleanupExpiredSessions removes expired sessions
func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
return err
}

View File

@@ -0,0 +1,193 @@
package store
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"fmt"
"github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/google/uuid"
)
// CreateDocumentShare creates a new share
func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, error) {
query := `
INSERT INTO document_shares (document_id, user_id, permission, created_by)
VALUES ($1, $2, $3, $4)
ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission
RETURNING id, document_id, user_id, permission, created_at, created_by
`
var share models.DocumentShare
err := s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission,
&share.CreatedAt, &share.CreatedBy,
)
if err != nil {
return nil, err
}
return &share, nil
}
// ListDocumentShares lists all shares for a document
func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) {
query := `
SELECT
ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by,
u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at
FROM document_shares ds
JOIN users u ON ds.user_id = u.id
WHERE ds.document_id = $1
ORDER BY ds.created_at DESC
`
rows, err := s.db.QueryContext(ctx, query, documentID)
if err != nil {
return nil, err
}
defer rows.Close()
var shares []models.DocumentShareWithUser
for rows.Next() {
var share models.DocumentShareWithUser
err := rows.Scan(
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
)
if err != nil {
return nil, err
}
shares = append(shares, share)
}
return shares, nil
}
// DeleteDocumentShare deletes a share
func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error {
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
return err
}
// CanViewDocument checks if user can view document (owner OR has any share)
func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := `
SELECT EXISTS(
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
UNION
SELECT 1 FROM document_shares WHERE document_id = $1 AND user_id = $2
)
`
var canView bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
return canView, err
}
// CanEditDocument checks if user can edit document (owner OR has edit share)
func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := `
SELECT EXISTS(
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
UNION
SELECT 1 FROM document_shares WHERE document_id = $1 AND user_id = $2 AND permission = 'edit'
)
`
var canEdit bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
return canEdit, err
}
// IsDocumentOwner checks if user is the owner
func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
var isOwner bool
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
if err == sql.ErrNoRows {
return false, nil
}
return isOwner, err
}
func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) {
// Generate random 32-byte token
tokenBytes := make([]byte, 32)
if _, err := rand.Read(tokenBytes); err != nil {
return "", fmt.Errorf("failed to generate token: %w", err)
}
token := base64.URLEncoding.EncodeToString(tokenBytes)
// Update document with share token
query := `
UPDATE documents
SET share_token = $1, is_public = true, updated_at = NOW()
WHERE id = $2
RETURNING share_token
`
var shareToken string
err := s.db.QueryRowContext(ctx, query, token, documentID).Scan(&shareToken)
if err != nil {
return "", fmt.Errorf("failed to set share token: %w", err)
}
return shareToken, nil
}
// ValidateShareToken checks if a share token is valid for a document
func (s *PostgresStore) ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error) {
query := `
SELECT EXISTS(
SELECT 1 FROM documents
WHERE id = $1 AND share_token = $2 AND is_public = true
)
`
var exists bool
err := s.db.QueryRowContext(ctx, query, documentID, token).Scan(&exists)
if err != nil {
return false, fmt.Errorf("failed to validate share token: %w", err)
}
return exists, nil
}
// RevokeShareToken removes the public share link from a document
func (s *PostgresStore) RevokeShareToken(ctx context.Context, documentID uuid.UUID) error {
query := `
UPDATE documents
SET share_token = NULL, is_public = false, updated_at = NOW()
WHERE id = $1
`
_, err := s.db.ExecContext(ctx, query, documentID)
if err != nil {
return fmt.Errorf("failed to revoke share token: %w", err)
}
return nil
}
// GetShareToken retrieves the current share token for a document (if exists)
func (s *PostgresStore) GetShareToken(ctx context.Context, documentID uuid.UUID) (string, bool, error) {
query := `
SELECT share_token FROM documents
WHERE id = $1 AND is_public = true AND share_token IS NOT NULL
`
var token string
err := s.db.QueryRowContext(ctx, query, documentID).Scan(&token)
if err == sql.ErrNoRows {
return "", false, nil
}
if err != nil {
return "", false, fmt.Errorf("failed to get share token: %w", err)
}
return token, true, nil
}

View File

@@ -0,0 +1,82 @@
package store
import (
"context"
"database/sql"
"fmt"
"github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/google/uuid"
)
// UpsertUser creates or updates user from OAuth profile
func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) {
query := `
INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at)
VALUES ($1, $2, $3, $4, $5, NOW())
ON CONFLICT (provider, provider_user_id)
DO UPDATE SET
email = EXCLUDED.email,
name = EXCLUDED.name,
avatar_url = EXCLUDED.avatar_url,
last_login_at = NOW(),
updated_at = NOW()
RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
`
var user models.User
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err != nil {
return nil, err
}
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
return &user, nil
}
// GetUserByID retrieves user by ID
func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
query := `
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
FROM users WHERE id = $1
`
var user models.User
err := s.db.QueryRowContext(ctx, query, userID).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &user, nil
}
// GetUserByEmail retrieves user by email
func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
query := `
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
FROM users WHERE email = $1
`
var user models.User
err := s.db.QueryRowContext(ctx, query, email).Scan(
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &user, nil
}