From 7f5f32179b1b93272e9ea89a88c2f53500be9a4c Mon Sep 17 00:00:00 2001 From: M1ngdaXie <156019134+M1ngdaXie@users.noreply.github.com> Date: Sat, 3 Jan 2026 12:59:53 -0800 Subject: [PATCH] feat: Enhance real-time collaboration features with user awareness and document sharing - Added user information (UserID, UserName, UserAvatar) to Client struct for presence tracking. - Implemented failure handling in the broadcastMessage function to manage send failures and disconnect clients if necessary. - Introduced document ownership and sharing capabilities: - Added OwnerID and Is_Public fields to Document model. - Created DocumentShare model for managing document sharing with permissions. - Implemented functions for creating, listing, and managing document shares in the Postgres store. - Added user management functionality: - Created User model and associated functions for user management in the Postgres store. - Implemented session management with token hashing for security. - Updated database schema with migrations for users, sessions, and document shares. - Enhanced frontend Yjs integration with awareness event logging for user connections and disconnections. --- backend/go.mod | 18 +- backend/go.sum | 20 +- backend/internal/auth/jwt.go | 63 ++++ backend/internal/auth/middleware.go | 193 +++++++++++ backend/internal/auth/oauth.go | 32 ++ backend/internal/handlers/auth.go | 302 ++++++++++++++++++ backend/internal/handlers/document.go | 274 ++++++++++------ backend/internal/handlers/share.go | 286 +++++++++++++++++ backend/internal/handlers/websocket.go | 152 +++++++-- backend/internal/hub/hub.go | 204 ++++++++---- backend/internal/models/document.go | 15 +- backend/internal/models/share.go | 30 ++ backend/internal/models/user.go | 48 +++ backend/internal/store/postgres.go | 144 ++++++++- backend/internal/store/session.go | 88 +++++ backend/internal/store/share.go | 193 +++++++++++ backend/internal/store/user.go | 82 +++++ .../scripts/001_add_users_and_sessions.sql | 52 +++ backend/scripts/002_add_document_shares.sql | 19 ++ frontend/src/hooks/useYjsDocument.ts | 78 ++++- frontend/src/lib/yjs.ts | 3 +- 21 files changed, 2064 insertions(+), 232 deletions(-) create mode 100644 backend/internal/auth/jwt.go create mode 100644 backend/internal/auth/middleware.go create mode 100644 backend/internal/auth/oauth.go create mode 100644 backend/internal/handlers/auth.go create mode 100644 backend/internal/handlers/share.go create mode 100644 backend/internal/models/share.go create mode 100644 backend/internal/models/user.go create mode 100644 backend/internal/store/session.go create mode 100644 backend/internal/store/share.go create mode 100644 backend/internal/store/user.go create mode 100644 backend/scripts/001_add_users_and_sessions.sql create mode 100644 backend/scripts/002_add_document_shares.sql diff --git a/backend/go.mod b/backend/go.mod index e43a4b4..8c1f659 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -3,25 +3,31 @@ module github.com/M1ngdaXie/realtime-collab go 1.25.3 require ( + github.com/gin-contrib/cors v1.7.6 + github.com/gin-gonic/gin v1.11.0 + github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/google/uuid v1.6.0 + github.com/gorilla/websocket v1.5.3 + github.com/joho/godotenv v1.5.1 + github.com/lib/pq v1.10.9 + golang.org/x/oauth2 v0.34.0 +) + +require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/bytedance/sonic v1.14.0 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/gabriel-vasile/mimetype v1.4.9 // indirect - github.com/gin-contrib/cors v1.7.6 // indirect github.com/gin-contrib/sse v1.1.0 // indirect - github.com/gin-gonic/gin v1.11.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.27.0 // indirect github.com/goccy/go-json v0.10.5 // indirect github.com/goccy/go-yaml v1.18.0 // indirect - github.com/google/uuid v1.6.0 // indirect - github.com/gorilla/websocket v1.5.3 // indirect - github.com/joho/godotenv v1.5.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect - github.com/lib/pq v1.10.9 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect diff --git a/backend/go.sum b/backend/go.sum index fad66b9..a6d3ed9 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -1,3 +1,5 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= @@ -5,9 +7,8 @@ github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFos github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM= -github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8= github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= github.com/gin-contrib/cors v1.7.6 h1:3gQ8GMzs1Ylpf70y8bMw4fVpycXIeX1ZemuSQIsnQQY= @@ -16,18 +17,22 @@ github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4= github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo= -github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= -github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -45,7 +50,6 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -53,6 +57,7 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= @@ -65,6 +70,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA= @@ -79,6 +86,8 @@ golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -92,4 +101,5 @@ google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7I google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/backend/internal/auth/jwt.go b/backend/internal/auth/jwt.go new file mode 100644 index 0000000..9f4cb84 --- /dev/null +++ b/backend/internal/auth/jwt.go @@ -0,0 +1,63 @@ +package auth + +import ( + "errors" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +// UserClaims defines the custom claims structure +// Senior Tip: Embed information that helps you avoid DB lookups later. +type UserClaims struct { + Name string `json:"user_name"` + Email string `json:"user_email"` + AvatarURL *string `json:"avatar_url"` // Nullable avatar URL to avoid DB queries + jwt.RegisteredClaims +} + +// GenerateJWT creates a stateless JWT token for a user +// Changed: Input is now userID (and optional role), not sessionID +func GenerateJWT(userID uuid.UUID, name string, email string, avatarURL *string, secret string, expiresIn time.Duration) (string, error) { + claims := UserClaims{ + Name: name, + Email: email, + AvatarURL: avatarURL, + RegisteredClaims: jwt.RegisteredClaims{ + // Standard claim "Subject" is technically where UserID belongs, + // but having a typed UserID field is easier for Go type assertions. + Subject: userID.String(), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiresIn)), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: "realtime-collab", // Your app name + }, + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + return token.SignedString([]byte(secret)) +} + +// ValidateJWT parses the token and extracts the UserClaims +// Changed: Returns *UserClaims so you can access UserID and Role directly +func ValidateJWT(tokenString, secret string) (*UserClaims, error) { + token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { + // Security Check: Always validate the signing algorithm + // to prevent "None" algorithm attacks. + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, errors.New("invalid signing method") + } + return []byte(secret), nil + }) + + if err != nil { + return nil, err + } + + // Type assertion to get our custom struct back + if claims, ok := token.Claims.(*UserClaims); ok && token.Valid { + return claims, nil + } + + return nil, errors.New("invalid token claims") +} \ No newline at end of file diff --git a/backend/internal/auth/middleware.go b/backend/internal/auth/middleware.go new file mode 100644 index 0000000..00a740c --- /dev/null +++ b/backend/internal/auth/middleware.go @@ -0,0 +1,193 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" + + "github.com/M1ngdaXie/realtime-collab/internal/store" + "github.com/gin-gonic/gin" + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) + +type contextKey string + +const UserContextKey contextKey = "user" +const ContextUserIDKey = "user_id" + +// AuthMiddleware provides auth middleware +type AuthMiddleware struct { + store store.Store + jwtSecret string +} + +// NewAuthMiddleware creates a new auth middleware +func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware { + return &AuthMiddleware{ + store: store, + jwtSecret: jwtSecret, + } +} + +// RequireAuth middleware requires valid authentication +func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc { + return func(c *gin.Context) { + fmt.Println("🔒 RequireAuth: Starting authentication check") + + user, claims, err := m.getUserFromToken(c) + + fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err) + if claims != nil { + fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email) + } + + if err != nil || user == nil { + fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user) + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + c.Abort() + return + } + + // Note: Name and Email might be empty for old JWT tokens + if claims.Name == "" || claims.Email == "" { + fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n") + } + + fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user) + c.Set(ContextUserIDKey, user) + c.Set("user_email", claims.Email) + c.Set("user_name", claims.Name) + if claims.AvatarURL != nil { + c.Set("avatar_url", *claims.AvatarURL) + } + c.Next() + } +} + +// OptionalAuth middleware sets user if authenticated, but doesn't require it +func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc { + return func(c *gin.Context) { + user, claims, _ := m.getUserFromToken(c) + if user != nil { + c.Set(string(UserContextKey), user) + c.Set(ContextUserIDKey, user) + if claims != nil { + c.Set("user_email", claims.Email) + c.Set("user_name", claims.Name) + if claims.AvatarURL != nil { + c.Set("avatar_url", *claims.AvatarURL) + } + } + } + c.Next() + } +} + +// getUserFromToken parses the JWT and returns the UserID and the full Claims (for name/email) +// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error) +func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) { + authHeader := c.GetHeader("Authorization") + fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader) + + if authHeader == "" { + fmt.Println("⚠️ getUserFromToken: No Authorization header") + return nil, nil, nil + } + + parts := strings.Split(authHeader, " ") + if len(parts) != 2 || parts[0] != "Bearer" { + fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0]) + return nil, nil, nil + } + + tokenString := parts[1] + fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))]) + + token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { + // 必须要验证签名算法是 HMAC (HS256) + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + return []byte(m.jwtSecret), nil + }) + + if err != nil { + fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err) + return nil, nil, err + } + + // 2. 验证 Token 有效性并提取 Claims + if claims, ok := token.Claims.(*UserClaims); ok && token.Valid { + // 3. 把 String 类型的 Subject 转回 UUID + // 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String() + userID, err := uuid.Parse(claims.Subject) + if err != nil { + fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err) + return nil, nil, fmt.Errorf("invalid user ID in token") + } + + // 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email) + // 这一步完全没有查数据库,速度极快 + fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email) + return &userID, claims, nil + } + + fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid") + return nil, nil, fmt.Errorf("invalid token claims") +} + +// GetUserFromContext extracts user ID from context +func GetUserFromContext(c *gin.Context) *uuid.UUID { + // 修正点:使用和存入时完全一样的 Key + val, exists := c.Get(ContextUserIDKey) + fmt.Println("within getFromContext the id is ... ") + fmt.Println(val); + if !exists { + return nil + } + + // 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型) + uid, ok := val.(*uuid.UUID) + if !ok { + return nil + } + + return uid +} + +// ValidateToken validates a JWT token and returns user ID, name, and avatar URL from JWT claims +func (m *AuthMiddleware) ValidateToken(tokenString string) (*uuid.UUID, string, string, error) { + // Parse and validate JWT + claims, err := ValidateJWT(tokenString, m.jwtSecret) + if err != nil { + return nil, "", "", fmt.Errorf("invalid token: %w", err) + } + + // Parse user ID from claims + userID, err := uuid.Parse(claims.Subject) + if err != nil { + return nil, "", "", fmt.Errorf("invalid user ID in token: %w", err) + } + + // Get session from database by token (for revocation capability) + session, err := m.store.GetSessionByToken(context.Background(), tokenString) + if err != nil { + return nil, "", "", fmt.Errorf("session not found: %w", err) + } + + // Verify session UserID matches JWT Subject + if session.UserID != userID { + return nil, "", "", fmt.Errorf("session ID mismatch") + } + + // Extract avatar URL from claims (handle nil gracefully) + avatarURL := "" + if claims.AvatarURL != nil { + avatarURL = *claims.AvatarURL + } + + // Return user data from JWT claims - no DB query needed! + return &userID, claims.Name, avatarURL, nil +} diff --git a/backend/internal/auth/oauth.go b/backend/internal/auth/oauth.go new file mode 100644 index 0000000..5dad179 --- /dev/null +++ b/backend/internal/auth/oauth.go @@ -0,0 +1,32 @@ +package auth + +import ( + "golang.org/x/oauth2" + "golang.org/x/oauth2/github" + "golang.org/x/oauth2/google" +) + +// GetGoogleOAuthConfig returns Google OAuth2 config +func GetGoogleOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config { + return &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURL, + Scopes: []string{ + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/userinfo.profile", + }, + Endpoint: google.Endpoint, + } +} + +// GetGitHubOAuthConfig returns GitHub OAuth2 config +func GetGitHubOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config { + return &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURL, + Scopes: []string{"user:email"}, + Endpoint: github.Endpoint, + } +} diff --git a/backend/internal/handlers/auth.go b/backend/internal/handlers/auth.go new file mode 100644 index 0000000..a4a1243 --- /dev/null +++ b/backend/internal/handlers/auth.go @@ -0,0 +1,302 @@ +package handlers + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "os" + "time" + + "github.com/M1ngdaXie/realtime-collab/internal/auth" + "github.com/M1ngdaXie/realtime-collab/internal/models" + "github.com/M1ngdaXie/realtime-collab/internal/store" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "golang.org/x/oauth2" +) + +type AuthHandler struct { + store store.Store + googleConfig *oauth2.Config + githubConfig *oauth2.Config + jwtSecret string + frontendURL string +} + +func NewAuthHandler(store store.Store, jwtSecret, frontendURL string) *AuthHandler { + googleConfig := auth.GetGoogleOAuthConfig( + os.Getenv("GOOGLE_CLIENT_ID"), + os.Getenv("GOOGLE_CLIENT_SECRET"), + os.Getenv("GOOGLE_REDIRECT_URL"), + ) + + githubConfig := auth.GetGitHubOAuthConfig( + os.Getenv("GITHUB_CLIENT_ID"), + os.Getenv("GITHUB_CLIENT_SECRET"), + os.Getenv("GITHUB_REDIRECT_URL"), + ) + + return &AuthHandler{ + store: store, + googleConfig: googleConfig, + githubConfig: githubConfig, + jwtSecret: jwtSecret, + frontendURL: frontendURL, + } +} + +// GoogleLogin redirects to Google OAuth +func (h *AuthHandler) GoogleLogin(c *gin.Context) { + // Generate random state and set cookie + oauthState := generateStateOauthCookie(c.Writer) + url := h.googleConfig.AuthCodeURL(oauthState, oauth2.AccessTypeOffline) + c.Redirect(http.StatusTemporaryRedirect, url) +} + +// GoogleCallback handles Google OAuth callback +func (h *AuthHandler) GoogleCallback(c *gin.Context) { + oauthState, err := c.Cookie("oauthstate") + if err != nil || c.Query("state") != oauthState { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"}) + return + } + log.Println("Google callback state:", c.Query("state")) + // Exchange code for token + token, err := h.googleConfig.Exchange(context.Background(), c.Query("code")) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange token"}) + return + } + + // Get user info from Google + client := h.googleConfig.Client(context.Background(), token) + resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo") + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"}) + return + } + log.Println("Google user info response status:", resp.Status) + log.Println("Google user info response headers:", resp.Header) + defer resp.Body.Close() + + data, _ := io.ReadAll(resp.Body) + var userInfo struct { + ID string `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + Picture string `json:"picture"` + } + json.Unmarshal(data, &userInfo) + log.Println("Google user info:", userInfo) + // Upsert user in database + user, err := h.store.UpsertUser( + c.Request.Context(), + "google", + userInfo.ID, + userInfo.Email, + userInfo.Name, + &userInfo.Picture, + ) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) + return + } + + // Create session and JWT + jwt, err := h.createSessionAndJWT(c, user) + if err != nil { + fmt.Printf("❌ DATABASE ERROR: %v\n", err) + c.JSON(http.StatusInternalServerError, gin.H{ + "error": fmt.Sprintf("CreateSession Error: %v", err), + }) + return + } + + // Redirect to frontend with token + redirectURL := fmt.Sprintf("%s/auth/callback?token=%s", h.frontendURL, jwt) + c.Redirect(http.StatusTemporaryRedirect, redirectURL) +} + +// GithubLogin redirects to GitHub OAuth +func (h *AuthHandler) GithubLogin(c *gin.Context) { + url := h.githubConfig.AuthCodeURL("state", oauth2.AccessTypeOffline) + c.Redirect(http.StatusTemporaryRedirect, url) +} + +// GithubCallback handles GitHub OAuth callback +func (h *AuthHandler) GithubCallback(c *gin.Context) { + code := c.Query("code") + if code == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "No code provided"}) + return + } + + // Exchange code for token + token, err := h.githubConfig.Exchange(context.Background(), code) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange token"}) + return + } + + // Get user info from GitHub + client := h.githubConfig.Client(context.Background(), token) + + // Get user profile + resp, err := client.Get("https://api.github.com/user") + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"}) + return + } + defer resp.Body.Close() + + data, _ := io.ReadAll(resp.Body) + var userInfo struct { + ID int `json:"id"` + Login string `json:"login"` + Name string `json:"name"` + Email string `json:"email"` + AvatarURL string `json:"avatar_url"` + } + json.Unmarshal(data, &userInfo) + + // If email is not public, fetch it separately + if userInfo.Email == "" { + emailResp, _ := client.Get("https://api.github.com/user/emails") + if emailResp != nil { + defer emailResp.Body.Close() + emailData, _ := io.ReadAll(emailResp.Body) + var emails []struct { + Email string `json:"email"` + Primary bool `json:"primary"` + } + json.Unmarshal(emailData, &emails) + for _, e := range emails { + if e.Primary { + userInfo.Email = e.Email + break + } + } + } + } + + // Use login as name if name is empty + if userInfo.Name == "" { + userInfo.Name = userInfo.Login + } + + // Upsert user in database + user, err := h.store.UpsertUser( + c.Request.Context(), + "github", + fmt.Sprintf("%d", userInfo.ID), + userInfo.Email, + userInfo.Name, + &userInfo.AvatarURL, + ) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"}) + return + } + + // Create session and JWT + jwt, err := h.createSessionAndJWT(c, user) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"}) + return + } + + // Redirect to frontend with token + redirectURL := fmt.Sprintf("%s/auth/callback?token=%s", h.frontendURL, jwt) + c.Redirect(http.StatusTemporaryRedirect, redirectURL) +} + +// Me returns current user info +func (h *AuthHandler) Me(c *gin.Context) { + userID := auth.GetUserFromContext(c) + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + user, err := h.store.GetUserByID(c.Request.Context(), *userID) + if err != nil || user == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + + c.JSON(http.StatusOK, models.UserResponse{User: user}) +} + +// Logout invalidates the session +func (h *AuthHandler) Logout(c *gin.Context) { + authHeader := c.GetHeader("Authorization") + if authHeader == "" { + c.JSON(http.StatusOK, gin.H{"message": "Already logged out"}) + return + } + + // Extract token + token := "" + if len(authHeader) > 7 && authHeader[:7] == "Bearer " { + token = authHeader[7:] + } + + if token != "" { + h.store.DeleteSession(c.Request.Context(), token) + } + + c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"}) +} + +// Helper: create session and JWT +func (h *AuthHandler) createSessionAndJWT(c *gin.Context, user *models.User) (string, error) { + expiresAt := time.Now().Add(7 * 24 * time.Hour) // 7 days + + // Generate JWT first (we need it for session) - now includes avatar URL + jwt, err := auth.GenerateJWT(user.ID, user.Name, user.Email, user.AvatarURL, h.jwtSecret, 7*24*time.Hour) + if err != nil { + return "", err + } + + // Create session in database + sessionID := uuid.New() + userAgent := c.GetHeader("User-Agent") + ipAddress := c.ClientIP() + _, err = h.store.CreateSession( + c.Request.Context(), + user.ID, + sessionID, + jwt, + expiresAt, + &userAgent, + &ipAddress, + ) + if err != nil { + return "", err + } + + return jwt, nil +} +func generateStateOauthCookie(w http.ResponseWriter) string { + b := make([]byte, 16) + rand.Read(b) + state := base64.URLEncoding.EncodeToString(b) + + cookie := http.Cookie{ + Name: "oauthstate", + Value: state, + Expires: time.Now().Add(10 * time.Minute), + HttpOnly: true, // Prevents JavaScript access (XSS protection) + Secure: false, // Must be false for http://localhost (set true in production) + SameSite: http.SameSiteLaxMode, // ✅ Allows same-site OAuth redirects + Path: "/", // ✅ Ensures cookie is sent to all backend paths + } + http.SetCookie(w, &cookie) + + return state +} diff --git a/backend/internal/handlers/document.go b/backend/internal/handlers/document.go index 212a508..ba52f29 100644 --- a/backend/internal/handlers/document.go +++ b/backend/internal/handlers/document.go @@ -1,8 +1,10 @@ package handlers import ( + "fmt" "net/http" + "github.com/M1ngdaXie/realtime-collab/internal/auth" "github.com/M1ngdaXie/realtime-collab/internal/models" "github.com/M1ngdaXie/realtime-collab/internal/store" "github.com/gin-gonic/gin" @@ -10,135 +12,199 @@ import ( ) type DocumentHandler struct { - store *store.Store + store *store.PostgresStore } - func NewDocumentHandler(s *store.Store) *DocumentHandler { + func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler { return &DocumentHandler{store: s} } - // CreateDocument creates a new document - func (h *DocumentHandler) CreateDocument(c *gin.Context) { - var req models.CreateDocumentRequest - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - // Validate document type - if req.Type != models.DocumentTypeEditor && req.Type != models.DocumentTypeKanban { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document type"}) - return - } + // CreateDocument creates a new document (requires auth) +func (h *DocumentHandler) CreateDocument(c *gin.Context) { + fmt.Println("getting userId right now.... ") + userID := auth.GetUserFromContext(c) + fmt.Println(userID) + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } - doc, err := h.store.CreateDocument(req.Name, req.Type) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + var req models.CreateDocumentRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } - c.JSON(http.StatusCreated, doc) - } + // Create document with owner_id + doc, err := h.store.CreateDocumentWithOwner(req.Name, req.Type, userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to create document: %v", err)}) + return + } + + c.JSON(http.StatusCreated, doc) +} - // ListDocuments returns all documents func (h *DocumentHandler) ListDocuments(c *gin.Context) { - documents, err := h.store.ListDocuments() - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + userID := auth.GetUserFromContext(c) - if documents == nil { - documents = []models.Document{} - } + var docs []models.Document + var err error + + if userID != nil { + // Authenticated: show owned + shared documents + docs, err = h.store.ListUserDocuments(c.Request.Context(), *userID) + } else { + c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("we dont know you: %v", err)}) + } + + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list documents"}) + return + } + + c.JSON(http.StatusOK, models.DocumentListResponse{ + Documents: docs, + Total: len(docs), + }) +} - c.JSON(http.StatusOK, models.DocumentListResponse{ - Documents: documents, - Total: len(documents), - }) - } - // GetDocument returns a single document func (h *DocumentHandler) GetDocument(c *gin.Context) { - idStr := c.Param("id") - id, err := uuid.Parse(idStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"}) - return - } + id, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) + return + } - doc, err := h.store.GetDocument(id) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "document not found"}) - return - } + userID := auth.GetUserFromContext(c) - c.JSON(http.StatusOK, doc) - } + // Check permission if authenticated + if userID != nil { + canView, err := h.store.CanViewDocument(c.Request.Context(), id, *userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"}) + return + } + if !canView { + c.JSON(http.StatusForbidden, gin.H{"error": "Access denied"}) + return + } + }else{ + c.JSON("this file is not public") + return + } + doc, err := h.store.GetDocument(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Document not found"}) + return + } + + c.JSON(http.StatusOK, doc) +} // GetDocumentState returns the Yjs state for a document - func (h *DocumentHandler) GetDocumentState(c *gin.Context) { - idStr := c.Param("id") - id, err := uuid.Parse(idStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"}) - return - } + // GetDocumentState retrieves document state (requires view permission) +func (h *DocumentHandler) GetDocumentState(c *gin.Context) { + id, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) + return + } - doc, err := h.store.GetDocument(id) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "document not found"}) - return - } + userID := auth.GetUserFromContext(c) - // Return binary state - if doc.YjsState == nil { - c.Data(http.StatusOK, "application/octet-stream", []byte{}) - return - } + // Check permission if authenticated + if userID != nil { + canView, err := h.store.CanViewDocument(c.Request.Context(), id, *userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"}) + return + } + if !canView { + c.JSON(http.StatusForbidden, gin.H{"error": "Access denied"}) + return + } + } - c.Data(http.StatusOK, "application/octet-stream", doc.YjsState) - } + doc, err := h.store.GetDocument(id) + if err != nil { + c.JSON(http.StatusNotFound, gin.H{"error": "Document not found"}) + return + } - // UpdateDocumentState updates the Yjs state for a document - func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) { - idStr := c.Param("id") - id, err := uuid.Parse(idStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"}) - return - } + c.Data(http.StatusOK, "application/octet-stream", doc.YjsState) +} - // Read binary body - state, err := c.GetRawData() - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"}) - return - } + // UpdateDocumentState updates document state (requires edit permission) +func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) { + id, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) + return + } - err = h.store.UpdateDocumentState(id, state) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + userID := auth.GetUserFromContext(c) + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } - c.JSON(http.StatusOK, gin.H{"message": "state updated successfully"}) - } + // Check edit permission + canEdit, err := h.store.CanEditDocument(c.Request.Context(), id, *userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"}) + return + } + if !canEdit { + c.JSON(http.StatusForbidden, gin.H{"error": "Edit access denied"}) + return + } - // DeleteDocument deletes a document - func (h *DocumentHandler) DeleteDocument(c *gin.Context) { - idStr := c.Param("id") - id, err := uuid.Parse(idStr) - if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"}) - return - } + var req models.UpdateStateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } - err = h.store.DeleteDocument(id) - if err != nil { - c.JSON(http.StatusNotFound, gin.H{"error": "document not found"}) - return - } + if err := h.store.UpdateDocumentState(id, req.State); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update state"}) + return + } - c.JSON(http.StatusOK, gin.H{"message": "document deleted successfully"}) - } + c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"}) +} + + // DeleteDocument deletes a document (owner only) +func (h *DocumentHandler) DeleteDocument(c *gin.Context) { + id, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) + return + } + + userID := auth.GetUserFromContext(c) + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + // Check ownership + isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), id, *userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"}) + return + } + if !isOwner { + c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can delete documents"}) + return + } + + if err := h.store.DeleteDocument(id); err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete document"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"}) +} \ No newline at end of file diff --git a/backend/internal/handlers/share.go b/backend/internal/handlers/share.go new file mode 100644 index 0000000..2534923 --- /dev/null +++ b/backend/internal/handlers/share.go @@ -0,0 +1,286 @@ +package handlers + +import ( + "fmt" + "net/http" + "os" // Add this + + "github.com/M1ngdaXie/realtime-collab/internal/auth" + "github.com/M1ngdaXie/realtime-collab/internal/models" + "github.com/M1ngdaXie/realtime-collab/internal/store" + "github.com/gin-gonic/gin" + "github.com/google/uuid" +) + +type ShareHandler struct { + store store.Store +} + +func NewShareHandler(store store.Store) *ShareHandler { + return &ShareHandler{store: store} +} + +// CreateShare creates a new document share +func (h *ShareHandler) CreateShare(c *gin.Context) { + userID := auth.GetUserFromContext(c) + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + documentID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) + return + } + + // Check if user is owner + isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"}) + return + } + if !isOwner { + c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can share documents"}) + return + } + + var req models.CreateShareRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + // Get user by email + targetUser, err := h.store.GetUserByEmail(c.Request.Context(), req.UserEmail) + if err != nil || targetUser == nil { + c.JSON(http.StatusNotFound, gin.H{"error": "User not found"}) + return + } + + // Create share + share, err := h.store.CreateDocumentShare( + c.Request.Context(), + documentID, + targetUser.ID, + req.Permission, + userID, + ) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create share"}) + return + } + + c.JSON(http.StatusCreated, share) +} + +// ListShares lists all shares for a document +func (h *ShareHandler) ListShares(c *gin.Context) { + userID := auth.GetUserFromContext(c) + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + documentID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) + return + } + + // Check if user is owner + isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"}) + return + } + if !isOwner { + c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can view shares"}) + return + } + + shares, err := h.store.ListDocumentShares(c.Request.Context(), documentID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list shares"}) + return + } + + c.JSON(http.StatusOK, models.ShareListResponse{Shares: shares}) +} + +// DeleteShare removes a share +func (h *ShareHandler) DeleteShare(c *gin.Context) { + userID := auth.GetUserFromContext(c) + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + documentID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) + return + } + + targetUserID, err := uuid.Parse(c.Param("userId")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"}) + return + } + + // Check if user is owner + isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"}) + return + } + if !isOwner { + c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can delete shares"}) + return + } + + err = h.store.DeleteDocumentShare(c.Request.Context(), documentID, targetUserID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete share"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "Share deleted successfully"}) +} +// CreateShareLink generates a public share link +func (h *ShareHandler) CreateShareLink(c *gin.Context) { + documentID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) + return + } + + userID := auth.GetUserFromContext(c) + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + // Check if user is owner + isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"}) + return + } + if !isOwner { + c.JSON(http.StatusForbidden, gin.H{"error": "Only document owner can create share links"}) + return + } + + // Parse request body + var req struct { + Permission string `json:"permission" binding:"required,oneof=view edit"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Permission must be 'view' or 'edit'"}) + return + } + + // Generate share token + token, err := h.store.GenerateShareToken(c.Request.Context(), documentID, req.Permission) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate share link"}) + return + } + + // Get frontend URL from env + frontendURL := os.Getenv("FRONTEND_URL") + if frontendURL == "" { + frontendURL = "http://localhost:5173" + } + + shareURL := fmt.Sprintf("%s/editor/%s?share=%s", frontendURL, documentID.String(), token) + + c.JSON(http.StatusOK, gin.H{ + "url": shareURL, + "token": token, + "permission": req.Permission, + }) +} + +// GetShareLink retrieves the current public share link +func (h *ShareHandler) GetShareLink(c *gin.Context) { + documentID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) + return + } + + userID := auth.GetUserFromContext(c) + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + // Check if user is owner + isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"}) + return + } + if !isOwner { + c.JSON(http.StatusForbidden, gin.H{"error": "Only document owner can view share links"}) + return + } + + token, exists, err := h.store.GetShareToken(c.Request.Context(), documentID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get share link"}) + return + } + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": "No public share link exists"}) + return + } + + frontendURL := os.Getenv("FRONTEND_URL") + if frontendURL == "" { + frontendURL = "http://localhost:5173" + } + + shareURL := fmt.Sprintf("%s/editor/%s?share=%s", frontendURL, documentID.String(), token) + + c.JSON(http.StatusOK, gin.H{ + "url": shareURL, + "token": token, + }) +} + +// RevokeShareLink removes the public share link +func (h *ShareHandler) RevokeShareLink(c *gin.Context) { + documentID, err := uuid.Parse(c.Param("id")) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) + return + } + + userID := auth.GetUserFromContext(c) + if userID == nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) + return + } + + // Check if user is owner + isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"}) + return + } + if !isOwner { + c.JSON(http.StatusForbidden, gin.H{"error": "Only document owner can revoke share links"}) + return + } + + err = h.store.RevokeShareToken(c.Request.Context(), documentID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to revoke share link"}) + return + } + + c.JSON(http.StatusOK, gin.H{"message": "Share link revoked"}) +} \ No newline at end of file diff --git a/backend/internal/handlers/websocket.go b/backend/internal/handlers/websocket.go index d5ce33d..07cbf39 100644 --- a/backend/internal/handlers/websocket.go +++ b/backend/internal/handlers/websocket.go @@ -3,57 +3,147 @@ package handlers import ( "log" "net/http" + "os" + "github.com/M1ngdaXie/realtime-collab/internal/auth" "github.com/M1ngdaXie/realtime-collab/internal/hub" + "github.com/M1ngdaXie/realtime-collab/internal/store" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/gorilla/websocket" ) - var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - // Allow all origins for development - // TODO: Restrict in production - return true - }, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + // Check origin against allowed origins from environment + allowedOrigins := os.Getenv("ALLOWED_ORIGINS") + if allowedOrigins == "" { + // Default for development + origin := r.Header.Get("Origin") + return origin == "http://localhost:5173" || origin == "http://localhost:3000" + } + // Production: validate against ALLOWED_ORIGINS + // TODO: Parse and validate origin + return true + }, } - type WebSocketHandler struct { - hub *hub.Hub - } +type WebSocketHandler struct { + hub *hub.Hub + store store.Store +} - func NewWebSocketHandler(h *hub.Hub) *WebSocketHandler { - return &WebSocketHandler{hub: h} - } +func NewWebSocketHandler(h *hub.Hub, s store.Store) *WebSocketHandler { + return &WebSocketHandler{ + hub: h, + store: s, + } +} - func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context){ +func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) { roomID := c.Param("roomId") - - if(roomID == ""){ + if roomID == "" { c.JSON(http.StatusBadRequest, gin.H{"error": "roomId is required"}) return } - conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + + // Parse document ID + documentID, err := uuid.Parse(roomID) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upgrade to WebSocket"}) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"}) return } - - // Create a new client - clientID := uuid.New().String() - client := hub.NewClient(clientID, conn, wsh.hub, roomID) - // Register client with hub - wsh.hub.Register <- client + // Try to authenticate via JWT token or share token + var userID *uuid.UUID + var userName string + var userAvatar *string + authenticated := false - // Start read and write pumps in separate goroutines - go client.WritePump() - go client.ReadPump() + // Check for JWT token in query parameter + jwtToken := c.Query("token") + if jwtToken != "" { + // Validate JWT and get user data from token claims (no DB query!) + jwtSecret := os.Getenv("JWT_SECRET") + if jwtSecret == "" { + log.Println("JWT_SECRET not configured") + c.JSON(http.StatusInternalServerError, gin.H{"error": "Server configuration error"}) + return + } - log.Printf("WebSocket connection established for client %s in room %s", clientID, roomID) - } + authMiddleware := auth.NewAuthMiddleware(wsh.store, jwtSecret) + uid, name, avatar, err := authMiddleware.ValidateToken(jwtToken) + if err == nil && uid != nil { + // User data comes directly from JWT claims - no DB query needed! + userID = uid + userName = name + if avatar != "" { + userAvatar = &avatar + } + authenticated = true + } + } - + // If not authenticated via JWT, check for share token + if !authenticated { + shareToken := c.Query("share") + if shareToken != "" { + // Validate share token + valid, err := wsh.store.ValidateShareToken(c.Request.Context(), documentID, shareToken) + if err != nil { + log.Printf("Error validating share token: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to validate share token"}) + return + } + if !valid { + c.JSON(http.StatusForbidden, gin.H{"error": "Invalid or expired share token"}) + return + } + // Share token is valid, allow connection with anonymous user + userName = "Anonymous" + authenticated = true + } + } + + // If still not authenticated, reject connection + if !authenticated { + c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required. Provide 'token' or 'share' query parameter"}) + return + } + + // If authenticated with JWT, check document permissions + if userID != nil { + canView, err := wsh.store.CanViewDocument(c.Request.Context(), documentID, *userID) + if err != nil { + log.Printf("Error checking permissions: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"}) + return + } + if !canView { + c.JSON(http.StatusForbidden, gin.H{"error": "You don't have permission to access this document"}) + return + } + } + + // Upgrade connection + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + log.Printf("Failed to upgrade connection: %v", err) + return + } + + // Create client with user information + clientID := uuid.New().String() + client := hub.NewClient(clientID, userID, userName, userAvatar, conn, wsh.hub, roomID) + + // Register client + wsh.hub.Register <- client + + // Start goroutines + go client.WritePump() + go client.ReadPump() + + log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID) +} diff --git a/backend/internal/hub/hub.go b/backend/internal/hub/hub.go index e9dee91..97f535d 100644 --- a/backend/internal/hub/hub.go +++ b/backend/internal/hub/hub.go @@ -3,7 +3,9 @@ package hub import ( "log" "sync" + "time" + "github.com/google/uuid" "github.com/gorilla/websocket" ) @@ -14,11 +16,20 @@ type Message struct { } type Client struct { - ID string - Conn *websocket.Conn - send chan []byte - hub *Hub - roomID string + ID string + UserID *uuid.UUID // Authenticated user ID (nil for public share access) + UserName string // User's display name for presence + UserAvatar *string // User's avatar URL for presence + Conn *websocket.Conn + send chan []byte + sendMu sync.Mutex + sendClosed bool + hub *Hub + roomID string + mutex sync.Mutex + unregisterOnce sync.Once + failureCount int + failureMu sync.Mutex } type Room struct { ID string @@ -74,54 +85,99 @@ func (h *Hub) registerClient(client *Client) { log.Printf("Client %s joined room %s (total clients: %d)", client.ID, client.roomID, len(room.clients)) } func (h *Hub) unregisterClient(client *Client) { - h.mu.Lock() - defer h.mu.Unlock() + h.mu.Lock() + defer h.mu.Unlock() - room, exists := h.rooms[client.roomID] - if !exists { - log.Printf("Room %s does not exist for client %s", client.roomID, client.ID) - return - } - room.mu.Lock() - if _, ok := room.clients[client]; ok { - delete(room.clients, client) - close(client.send) - log.Printf("Client %s disconnected from room %s", client.ID, client.roomID) - } + room, exists := h.rooms[client.roomID] + if !exists { + log.Printf("Room %s does not exist for client %s", client.roomID, client.ID) + return + } - room.mu.Unlock() - log.Printf("Client %s left room %s (total clients: %d)", client.ID, client.roomID, len(room.clients)) + room.mu.Lock() + defer room.mu.Unlock() - if len(room.clients) == 0 { - delete(h.rooms, client.roomID) - log.Printf("Deleted empty room with ID: %s", client.roomID) - } + if _, ok := room.clients[client]; ok { + delete(room.clients, client) + + // Safely close send channel exactly once + client.sendMu.Lock() + if !client.sendClosed { + close(client.send) + client.sendClosed = true + } + client.sendMu.Unlock() + + log.Printf("Client %s disconnected from room %s (total clients: %d)", + client.ID, client.roomID, len(room.clients)) + } + + if len(room.clients) == 0 { + delete(h.rooms, client.roomID) + log.Printf("Deleted empty room with ID: %s", client.roomID) + } } + +const ( + writeWait = 10 * time.Second + pongWait = 60 * time.Second + pingPeriod = (pongWait * 9) / 10 // 54 seconds + maxSendFailures = 5 +) + func (h *Hub) broadcastMessage(message *Message) { - h.mu.RLock() - room, exists := h.rooms[message.RoomID] - h.mu.RUnlock() - if !exists { - log.Printf("Room %s does not exist for broadcasting", message.RoomID) - return - } + h.mu.RLock() + room, exists := h.rooms[message.RoomID] + h.mu.RUnlock() + if !exists { + log.Printf("Room %s does not exist for broadcasting", message.RoomID) + return + } - room.mu.RLock() - defer room.mu.RUnlock() - for client := range room.clients { - if client != message.sender { - select { - case client.send <- message.Data: - default: - log.Printf("Failed to send to client %s (channel full)", client.ID) - } - } - } + room.mu.RLock() + defer room.mu.RUnlock() + + for client := range room.clients { + if client != message.sender { + select { + case client.send <- message.Data: + // Success - reset failure count + client.failureMu.Lock() + client.failureCount = 0 + client.failureMu.Unlock() + + default: + // Failed - increment failure count + client.failureMu.Lock() + client.failureCount++ + currentFailures := client.failureCount + client.failureMu.Unlock() + + log.Printf("Failed to send to client %s (channel full, failures: %d/%d)", + client.ID, currentFailures, maxSendFailures) + + // Disconnect if threshold exceeded + if currentFailures >= maxSendFailures { + log.Printf("Client %s exceeded max send failures, disconnecting", client.ID) + go func(c *Client) { + c.unregister() + c.Conn.Close() + }(client) + } + } + } + } } + func (c *Client) ReadPump() { + c.Conn.SetReadDeadline(time.Now().Add(pongWait)) + c.Conn.SetPongHandler(func(string) error { + c.Conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) defer func() { - c.hub.Unregister <- c + c.unregister() c.Conn.Close() }() for { @@ -141,24 +197,54 @@ func (c *Client) ReadPump() { } func (c *Client) WritePump() { - defer func() { - c.Conn.Close() - }() - for message := range c.send { - err := c.Conn.WriteMessage(websocket.BinaryMessage, message) - if err != nil { - log.Printf("Error writing message to client %s: %v", c.ID, err) - break - } - } + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.unregister() // NEW: Now WritePump also unregisters + c.Conn.Close() + }() + + for { + select { + case message, ok := <-c.send: + c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + // Hub closed the channel + c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + err := c.Conn.WriteMessage(websocket.BinaryMessage, message) + if err != nil { + log.Printf("Error writing message to client %s: %v", c.ID, err) + return + } + + case <-ticker.C: + c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { + log.Printf("Ping failed for client %s: %v", c.ID, err) + return + } + } + } } -func NewClient(id string, conn *websocket.Conn, hub *Hub, roomID string) *Client { + +func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string, conn *websocket.Conn, hub *Hub, roomID string) *Client { return &Client{ - ID: id, - Conn: conn, - send: make(chan []byte, 256), - hub: hub, - roomID: roomID, + ID: id, + UserID: userID, + UserName: userName, + UserAvatar: userAvatar, + Conn: conn, + send: make(chan []byte, 256), + hub: hub, + roomID: roomID, } +} +func (c *Client) unregister() { + c.unregisterOnce.Do(func() { + c.hub.Unregister <- c + }) } \ No newline at end of file diff --git a/backend/internal/models/document.go b/backend/internal/models/document.go index 7db851b..7fe6e10 100644 --- a/backend/internal/models/document.go +++ b/backend/internal/models/document.go @@ -14,14 +14,17 @@ const ( ) type Document struct { - ID uuid.UUID `json:"id"` - Name string `json:"name"` - Type DocumentType `json:"type"` - YjsState []byte `json:"-"` // Don't expose binary data in JSON - CreatedAt time.Time `json:"created_at"` - UpdatedAt time.Time `json:"updated_at"` + ID uuid.UUID `json:"id"` + Name string `json:"name"` + Type DocumentType `json:"type"` + YjsState []byte `json:"-"` + OwnerID *uuid.UUID `json:"owner_id"` // NEW + Is_Public bool `json:"is_public"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` } + type CreateDocumentRequest struct { Name string `json:"name" binding:"required"` Type DocumentType `json:"type" binding:"required"` diff --git a/backend/internal/models/share.go b/backend/internal/models/share.go new file mode 100644 index 0000000..9690f17 --- /dev/null +++ b/backend/internal/models/share.go @@ -0,0 +1,30 @@ +package models + +import ( + "time" + + "github.com/google/uuid" +) + +type DocumentShare struct { + ID uuid.UUID `json:"id"` + DocumentID uuid.UUID `json:"document_id"` + UserID uuid.UUID `json:"user_id"` + Permission string `json:"permission"` // "view" or "edit" + CreatedAt time.Time `json:"created_at"` + CreatedBy *uuid.UUID `json:"created_by"` +} + +type CreateShareRequest struct { + UserEmail string `json:"user_email" binding:"required"` + Permission string `json:"permission" binding:"required,oneof=view edit"` +} + +type ShareListResponse struct { + Shares []DocumentShareWithUser `json:"shares"` +} + +type DocumentShareWithUser struct { + DocumentShare + User User `json:"user"` +} diff --git a/backend/internal/models/user.go b/backend/internal/models/user.go new file mode 100644 index 0000000..3a1bbc6 --- /dev/null +++ b/backend/internal/models/user.go @@ -0,0 +1,48 @@ +package models + +import ( + "time" + + "github.com/google/uuid" +) + +type User struct { + ID uuid.UUID `json:"id"` + Email string `json:"email"` + Name string `json:"name"` + AvatarURL *string `json:"avatar_url"` + Provider string `json:"provider"` + ProviderUserID string `json:"-"` // Don't expose + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + LastLoginAt *time.Time `json:"last_login_at"` +} + +type Session struct { + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + TokenHash string `json:"-"` // SHA-256 hash of JWT + ExpiresAt time.Time `json:"expires_at"` + CreatedAt time.Time `json:"created_at"` + UserAgent *string `json:"user_agent"` + IPAddress *string `json:"ip_address"` +} + +type OAuthToken struct { + ID uuid.UUID `json:"id"` + UserID uuid.UUID `json:"user_id"` + Provider string `json:"provider"` + AccessToken string `json:"-"` // Don't expose + RefreshToken *string `json:"-"` + TokenType string `json:"token_type"` + ExpiresAt time.Time `json:"expires_at"` + Scope *string `json:"scope"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// Response for /auth/me endpoint +type UserResponse struct { + User *User `json:"user"` + Token string `json:"token,omitempty"` +} diff --git a/backend/internal/store/postgres.go b/backend/internal/store/postgres.go index 003112f..1fc712e 100644 --- a/backend/internal/store/postgres.go +++ b/backend/internal/store/postgres.go @@ -1,6 +1,7 @@ package store import ( + "context" "database/sql" "fmt" "time" @@ -10,11 +11,49 @@ import ( _ "github.com/lib/pq" // PostgreSQL driver ) -type Store struct{ - db *sql.DB +// Store interface defines all database operations +type Store interface { + // Document operations + CreateDocument(name string, docType models.DocumentType) (*models.Document, error) + CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS + GetDocument(id uuid.UUID) (*models.Document, error) + ListDocuments() ([]models.Document, error) + ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS + UpdateDocumentState(id uuid.UUID, state []byte) error + DeleteDocument(id uuid.UUID) error + + // User operations + UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) + GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) + GetUserByEmail(ctx context.Context, email string) (*models.User, error) + + // Session operations + CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) + GetSessionByToken(ctx context.Context, token string) (*models.Session, error) + DeleteSession(ctx context.Context, token string) error + CleanupExpiredSessions(ctx context.Context) error + + // Share operations + CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, error) + ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) + DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error + CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) + CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) + IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) + GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) + ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error) + RevokeShareToken(ctx context.Context, documentID uuid.UUID) error + GetShareToken(ctx context.Context, documentID uuid.UUID) (string, bool, error) + + Close() error } -func NewStore(databaseUrl string) (*Store, error) { + +type PostgresStore struct { + db *sql.DB +} + +func NewPostgresStore(databaseUrl string) (*PostgresStore, error) { db, error := sql.Open("postgres", databaseUrl) if error != nil { return nil, error @@ -25,14 +64,14 @@ func NewStore(databaseUrl string) (*Store, error) { db.SetMaxOpenConns(25) db.SetMaxIdleConns(5) db.SetConnMaxLifetime(5 * time.Minute) - return &Store{db: db}, nil + return &PostgresStore{db: db}, nil } -func (s *Store) Close() error { +func (s *PostgresStore) Close() error { return s.db.Close() } -func (s *Store) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) { +func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) { doc := &models.Document{ ID: uuid.New(), Name: name, @@ -62,7 +101,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model } // GetDocument retrieves a document by ID - func (s *Store) GetDocument(id uuid.UUID) (*models.Document, error) { + func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) { doc := &models.Document{} query := ` @@ -92,7 +131,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model // ListDocuments retrieves all documents - func (s *Store) ListDocuments() ([]models.Document, error) { + func (s *PostgresStore) ListDocuments() ([]models.Document, error) { query := ` SELECT id, name, type, created_at, updated_at FROM documents @@ -118,7 +157,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model return documents, nil } - func (s *Store) UpdateDocumentState(id uuid.UUID, state []byte) error { + func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error { query := ` UPDATE documents SET yjs_state = $1, updated_at = $2 @@ -142,7 +181,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model return nil } - func (s *Store) DeleteDocument(id uuid.UUID) error { + func (s *PostgresStore) DeleteDocument(id uuid.UUID) error { query := `DELETE FROM documents WHERE id = $1` result, err := s.db.Exec(query, id) @@ -162,3 +201,88 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model return nil } +// CreateDocumentWithOwner creates a new document with owner +func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) { + // 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错) + if docType == "" { + docType = models.DocumentTypeEditor // Default to editor instead of invalid "text" + } + + // Validate that docType is one of the allowed values + if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban { + return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType) + } + + doc := &models.Document{ + ID: uuid.New(), + Name: name, + Type: docType, + YjsState: []byte{}, // 这里初始化了空字节 + OwnerID: ownerID, + Is_Public: false, // 显式设置默认值 + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + // 2. 补全了 yjs_state 和 is_public + query := ` + INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at + ` + + // 3. Scan 的时候也要对应加上 + err := s.db.QueryRow(query, + doc.ID, + doc.Name, + doc.Type, + doc.OwnerID, + doc.YjsState, // $5 + doc.Is_Public, // $6 + doc.CreatedAt, + doc.UpdatedAt, + ).Scan( + &doc.ID, + &doc.Name, + &doc.Type, + &doc.OwnerID, + &doc.YjsState, // Scan 回来 + &doc.Is_Public, // Scan 回来 + &doc.CreatedAt, + &doc.UpdatedAt, + ) + + if err != nil { + return nil, fmt.Errorf("failed to create document: %w", err) + } + return doc, nil +} + +// ListUserDocuments lists documents owned by or shared with a user +func (s *PostgresStore) ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) { + query := ` + SELECT DISTINCT d.id, d.name, d.type, d.owner_id, d.created_at, d.updated_at + FROM documents d + LEFT JOIN document_shares ds ON d.id = ds.document_id + WHERE d.owner_id = $1 OR ds.user_id = $1 + ORDER BY d.created_at DESC + ` + + rows, err := s.db.QueryContext(ctx, query, userID) + if err != nil { + return nil, fmt.Errorf("failed to list user documents: %w", err) + } + defer rows.Close() + + var documents []models.Document + for rows.Next() { + var doc models.Document + err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.OwnerID, &doc.CreatedAt, &doc.UpdatedAt) + if err != nil { + return nil, fmt.Errorf("failed to scan document: %w", err) + } + documents = append(documents, doc) + } + + return documents, nil +} diff --git a/backend/internal/store/session.go b/backend/internal/store/session.go new file mode 100644 index 0000000..4e21686 --- /dev/null +++ b/backend/internal/store/session.go @@ -0,0 +1,88 @@ +package store + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "time" + + "github.com/M1ngdaXie/realtime-collab/internal/models" + "github.com/google/uuid" +) + +// CreateSession creates a new session +func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) { + // Hash the token before storing + hash := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(hash[:]) + + // 【修改点 1】: 在 SQL 里显式加上 id 字段 + // 注意:$1 变成了 id,后面的参数序号全部要顺延 (+1) + query := ` + INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address + ` + + var session models.Session + // 【修改点 2】: 在参数列表的最前面加上 sessionID + // 现在的对应关系: + // $1 -> sessionID + // $2 -> userID + // $3 -> tokenHash + // ... + err := s.db.QueryRowContext(ctx, query, + sessionID, // <--- 这里!把它传进去! + userID, + tokenHash, + expiresAt, + userAgent, + ipAddress, + ).Scan( + &session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt, + &session.CreatedAt, &session.UserAgent, &session.IPAddress, + ) + if err != nil { + return nil, err + } + + return &session, nil +} + +// GetSessionByToken retrieves session by JWT token +func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) { + hash := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(hash[:]) + + query := ` + SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address + FROM sessions + WHERE token_hash = $1 AND expires_at > NOW() + ` + + var session models.Session + err := s.db.QueryRowContext(ctx, query, tokenHash).Scan( + &session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt, + &session.CreatedAt, &session.UserAgent, &session.IPAddress, + ) + if err != nil { + return nil, err + } + + return &session, nil +} + +// DeleteSession deletes a session (logout) +func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error { + hash := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(hash[:]) + + _, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash) + return err +} + +// CleanupExpiredSessions removes expired sessions +func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error { + _, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()") + return err +} diff --git a/backend/internal/store/share.go b/backend/internal/store/share.go new file mode 100644 index 0000000..9477870 --- /dev/null +++ b/backend/internal/store/share.go @@ -0,0 +1,193 @@ +package store + +import ( + "context" + "crypto/rand" + "database/sql" + "encoding/base64" + "fmt" + + "github.com/M1ngdaXie/realtime-collab/internal/models" + "github.com/google/uuid" +) + +// CreateDocumentShare creates a new share +func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, error) { + query := ` + INSERT INTO document_shares (document_id, user_id, permission, created_by) + VALUES ($1, $2, $3, $4) + ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission + RETURNING id, document_id, user_id, permission, created_at, created_by + ` + + var share models.DocumentShare + err := s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan( + &share.ID, &share.DocumentID, &share.UserID, &share.Permission, + &share.CreatedAt, &share.CreatedBy, + ) + if err != nil { + return nil, err + } + + return &share, nil +} + +// ListDocumentShares lists all shares for a document +func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) { + query := ` + SELECT + ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by, + u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at + FROM document_shares ds + JOIN users u ON ds.user_id = u.id + WHERE ds.document_id = $1 + ORDER BY ds.created_at DESC + ` + + rows, err := s.db.QueryContext(ctx, query, documentID) + if err != nil { + return nil, err + } + defer rows.Close() + + var shares []models.DocumentShareWithUser + for rows.Next() { + var share models.DocumentShareWithUser + err := rows.Scan( + &share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy, + &share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider, + &share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt, + ) + if err != nil { + return nil, err + } + shares = append(shares, share) + } + + return shares, nil +} + +// DeleteDocumentShare deletes a share +func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error { + _, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID) + return err +} + +// CanViewDocument checks if user can view document (owner OR has any share) +func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) { + query := ` + SELECT EXISTS( + SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2 + UNION + SELECT 1 FROM document_shares WHERE document_id = $1 AND user_id = $2 + ) + ` + + var canView bool + err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView) + return canView, err +} + +// CanEditDocument checks if user can edit document (owner OR has edit share) +func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) { + query := ` + SELECT EXISTS( + SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2 + UNION + SELECT 1 FROM document_shares WHERE document_id = $1 AND user_id = $2 AND permission = 'edit' + ) + ` + + var canEdit bool + err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit) + return canEdit, err +} + +// IsDocumentOwner checks if user is the owner +func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) { + query := `SELECT owner_id = $2 FROM documents WHERE id = $1` + + var isOwner bool + err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner) + if err == sql.ErrNoRows { + return false, nil + } + return isOwner, err +} +func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) { + // Generate random 32-byte token + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + return "", fmt.Errorf("failed to generate token: %w", err) + } + token := base64.URLEncoding.EncodeToString(tokenBytes) + + // Update document with share token + query := ` + UPDATE documents + SET share_token = $1, is_public = true, updated_at = NOW() + WHERE id = $2 + RETURNING share_token + ` + + var shareToken string + err := s.db.QueryRowContext(ctx, query, token, documentID).Scan(&shareToken) + if err != nil { + return "", fmt.Errorf("failed to set share token: %w", err) + } + + return shareToken, nil +} + +// ValidateShareToken checks if a share token is valid for a document +func (s *PostgresStore) ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error) { + query := ` + SELECT EXISTS( + SELECT 1 FROM documents + WHERE id = $1 AND share_token = $2 AND is_public = true + ) + ` + + var exists bool + err := s.db.QueryRowContext(ctx, query, documentID, token).Scan(&exists) + if err != nil { + return false, fmt.Errorf("failed to validate share token: %w", err) + } + + return exists, nil +} + +// RevokeShareToken removes the public share link from a document +func (s *PostgresStore) RevokeShareToken(ctx context.Context, documentID uuid.UUID) error { + query := ` + UPDATE documents + SET share_token = NULL, is_public = false, updated_at = NOW() + WHERE id = $1 + ` + + _, err := s.db.ExecContext(ctx, query, documentID) + if err != nil { + return fmt.Errorf("failed to revoke share token: %w", err) + } + + return nil +} + +// GetShareToken retrieves the current share token for a document (if exists) +func (s *PostgresStore) GetShareToken(ctx context.Context, documentID uuid.UUID) (string, bool, error) { + query := ` + SELECT share_token FROM documents + WHERE id = $1 AND is_public = true AND share_token IS NOT NULL + ` + + var token string + err := s.db.QueryRowContext(ctx, query, documentID).Scan(&token) + if err == sql.ErrNoRows { + return "", false, nil + } + if err != nil { + return "", false, fmt.Errorf("failed to get share token: %w", err) + } + + return token, true, nil +} \ No newline at end of file diff --git a/backend/internal/store/user.go b/backend/internal/store/user.go new file mode 100644 index 0000000..18e7fd4 --- /dev/null +++ b/backend/internal/store/user.go @@ -0,0 +1,82 @@ +package store + +import ( + "context" + "database/sql" + "fmt" + + "github.com/M1ngdaXie/realtime-collab/internal/models" + "github.com/google/uuid" +) + +// UpsertUser creates or updates user from OAuth profile +func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) { + query := ` + INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at) + VALUES ($1, $2, $3, $4, $5, NOW()) + ON CONFLICT (provider, provider_user_id) + DO UPDATE SET + email = EXCLUDED.email, + name = EXCLUDED.name, + avatar_url = EXCLUDED.avatar_url, + last_login_at = NOW(), + updated_at = NOW() + RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at + ` + + var user models.User + err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan( + &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, + &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, + ) + if err != nil { + return nil, err + } + fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email) + + return &user, nil +} + +// GetUserByID retrieves user by ID +func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) { + query := ` + SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at + FROM users WHERE id = $1 + ` + + var user models.User + err := s.db.QueryRowContext(ctx, query, userID).Scan( + &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, + &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + + return &user, nil +} + +// GetUserByEmail retrieves user by email +func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) { + query := ` + SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at + FROM users WHERE email = $1 + ` + + var user models.User + err := s.db.QueryRowContext(ctx, query, email).Scan( + &user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider, + &user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt, + ) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + + return &user, nil +} diff --git a/backend/scripts/001_add_users_and_sessions.sql b/backend/scripts/001_add_users_and_sessions.sql new file mode 100644 index 0000000..5f3e1c0 --- /dev/null +++ b/backend/scripts/001_add_users_and_sessions.sql @@ -0,0 +1,52 @@ +-- Migration: Add users and sessions tables for authentication +-- Run this before 002_add_document_shares.sql + +-- Enable UUID extension +CREATE EXTENSION IF NOT EXISTS "uuid-ossp"; + +-- Users table +CREATE TABLE IF NOT EXISTS users ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + email VARCHAR(255) NOT NULL, + name VARCHAR(255) NOT NULL, + avatar_url TEXT, + provider VARCHAR(50) NOT NULL CHECK (provider IN ('google', 'github')), + provider_user_id VARCHAR(255) NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW(), + last_login_at TIMESTAMPTZ, + UNIQUE(provider, provider_user_id) +); + +CREATE INDEX idx_users_email ON users(email); +CREATE INDEX idx_users_provider ON users(provider, provider_user_id); + +COMMENT ON TABLE users IS 'Stores user accounts from OAuth providers'; +COMMENT ON COLUMN users.provider IS 'OAuth provider: google or github'; +COMMENT ON COLUMN users.provider_user_id IS 'User ID from OAuth provider'; + +-- Sessions table +CREATE TABLE IF NOT EXISTS sessions ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + token_hash VARCHAR(64) NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW(), + user_agent TEXT, + ip_address VARCHAR(45), + UNIQUE(token_hash) +); + +CREATE INDEX idx_sessions_user_id ON sessions(user_id); +CREATE INDEX idx_sessions_token_hash ON sessions(token_hash); +CREATE INDEX idx_sessions_expires_at ON sessions(expires_at); + +COMMENT ON TABLE sessions IS 'Stores active JWT sessions for revocation support'; +COMMENT ON COLUMN sessions.token_hash IS 'SHA-256 hash of JWT token'; +COMMENT ON COLUMN sessions.user_agent IS 'User agent string for device tracking'; + +-- Add owner_id to documents table if it doesn't exist +ALTER TABLE documents ADD COLUMN IF NOT EXISTS owner_id UUID REFERENCES users(id) ON DELETE SET NULL; +CREATE INDEX IF NOT EXISTS idx_documents_owner_id ON documents(owner_id); + +COMMENT ON COLUMN documents.owner_id IS 'User who created the document'; diff --git a/backend/scripts/002_add_document_shares.sql b/backend/scripts/002_add_document_shares.sql new file mode 100644 index 0000000..36090a4 --- /dev/null +++ b/backend/scripts/002_add_document_shares.sql @@ -0,0 +1,19 @@ +-- Migration: Add document sharing with permissions +-- Run against existing database + +CREATE TABLE IF NOT EXISTS document_shares ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + permission VARCHAR(20) NOT NULL CHECK (permission IN ('view', 'edit')), + created_at TIMESTAMPTZ DEFAULT NOW(), + created_by UUID REFERENCES users(id) ON DELETE SET NULL, + UNIQUE(document_id, user_id) +); + +CREATE INDEX idx_shares_document_id ON document_shares(document_id); +CREATE INDEX idx_shares_user_id ON document_shares(user_id); +CREATE INDEX idx_shares_permission ON document_shares(document_id, permission); + +COMMENT ON TABLE document_shares IS 'Stores per-user document access permissions'; +COMMENT ON COLUMN document_shares.permission IS 'Access level: view (read-only) or edit (read-write)'; diff --git a/frontend/src/hooks/useYjsDocument.ts b/frontend/src/hooks/useYjsDocument.ts index dd02874..a6784b2 100644 --- a/frontend/src/hooks/useYjsDocument.ts +++ b/frontend/src/hooks/useYjsDocument.ts @@ -1,10 +1,10 @@ import { useEffect, useState } from "react"; import { - createYjsDocument, - destroyYjsDocument, - getRandomColor, - getRandomName, - type YjsProviders, + createYjsDocument, + destroyYjsDocument, + getRandomColor, + getRandomName, + type YjsProviders, } from "../lib/yjs"; import { useAutoSave } from "./useAutoSave"; @@ -19,7 +19,6 @@ export const useYjsDocument = (documentId: string) => { let mounted = true; let currentProviders: YjsProviders | null = null; - // Create Yjs document and providers const initializeDocument = async () => { const yjsProviders = await createYjsDocument(documentId); currentProviders = yjsProviders; @@ -30,19 +29,75 @@ export const useYjsDocument = (documentId: string) => { } // Set user info for awareness + const userName = getRandomName(); + const userColor = getRandomColor(); yjsProviders.awareness.setLocalStateField("user", { - name: getRandomName(), - color: getRandomColor(), + name: userName, + color: userColor, }); + // NEW: Add awareness event logging + const handleAwarenessChange = ({ + added, + updated, + removed, + }: { + added: number[]; + updated: number[]; + removed: number[]; + }) => { + const states = yjsProviders.awareness.getStates(); + + added.forEach((clientId) => { + const state = states.get(clientId); + const user = state?.user; + console.log( + `[Awareness] User connected: ${ + user?.name || "Unknown" + } (ID: ${clientId})`, + { + color: user?.color, + clientId, + } + ); + }); + + updated.forEach((clientId) => { + const state = states.get(clientId); + const user = state?.user; + console.log( + `[Awareness] User updated: ${ + user?.name || "Unknown" + } (ID: ${clientId})` + ); + }); + + removed.forEach((clientId) => { + console.log(`[Awareness] User disconnected (ID: ${clientId})`); + }); + + console.log(`[Awareness] Total connected users: ${states.size}`); + }; + + yjsProviders.awareness.on("change", handleAwarenessChange); + // Listen for sync status yjsProviders.indexeddbProvider.on("synced", () => { console.log("IndexedDB synced"); setSynced(true); }); - yjsProviders.websocketProvider.on("status", (event: { status: string }) => { - console.log("WebSocket status:", event.status); + yjsProviders.websocketProvider.on( + "status", + (event: { status: string }) => { + console.log("WebSocket status:", event.status); + } + ); + + // Log local user info + console.log(`[Awareness] Local user initialized: ${userName}`, { + color: userColor, + clientId: yjsProviders.awareness.clientID, }); setProviders(yjsProviders); @@ -54,10 +109,13 @@ export const useYjsDocument = (documentId: string) => { return () => { mounted = false; if (currentProviders) { + console.log("[Awareness] Cleaning up local user"); + currentProviders.awareness.setLocalState(null); destroyYjsDocument(currentProviders); } }; }, [documentId]); + return { providers, synced }; }; diff --git a/frontend/src/lib/yjs.ts b/frontend/src/lib/yjs.ts index 4f8d75a..c8bc8c2 100644 --- a/frontend/src/lib/yjs.ts +++ b/frontend/src/lib/yjs.ts @@ -1,4 +1,5 @@ import { IndexeddbPersistence } from "y-indexeddb"; +import { Awareness } from "y-protocols/awareness"; import { WebsocketProvider } from "y-websocket"; import * as Y from "yjs"; import { documentsApi } from "../api/document"; @@ -9,7 +10,7 @@ export interface YjsProviders { ydoc: Y.Doc; websocketProvider: WebsocketProvider; indexeddbProvider: IndexeddbPersistence; - awareness: any; + awareness: Awareness; } export const createYjsDocument = async (documentId: string): Promise => {