feat: Enhance real-time collaboration features with user awareness and document sharing
- Added user information (UserID, UserName, UserAvatar) to Client struct for presence tracking. - Implemented failure handling in the broadcastMessage function to manage send failures and disconnect clients if necessary. - Introduced document ownership and sharing capabilities: - Added OwnerID and Is_Public fields to Document model. - Created DocumentShare model for managing document sharing with permissions. - Implemented functions for creating, listing, and managing document shares in the Postgres store. - Added user management functionality: - Created User model and associated functions for user management in the Postgres store. - Implemented session management with token hashing for security. - Updated database schema with migrations for users, sessions, and document shares. - Enhanced frontend Yjs integration with awareness event logging for user connections and disconnections.
This commit is contained in:
@@ -3,25 +3,31 @@ module github.com/M1ngdaXie/realtime-collab
|
||||
go 1.25.3
|
||||
|
||||
require (
|
||||
github.com/gin-contrib/cors v1.7.6
|
||||
github.com/gin-gonic/gin v1.11.0
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lib/pq v1.10.9
|
||||
golang.org/x/oauth2 v0.34.0
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
github.com/bytedance/sonic v1.14.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.9 // indirect
|
||||
github.com/gin-contrib/cors v1.7.6 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/gin-gonic/gin v1.11.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.27.0 // indirect
|
||||
github.com/goccy/go-json v0.10.5 // indirect
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/google/uuid v1.6.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.3 // indirect
|
||||
github.com/joho/godotenv v1.5.1 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/lib/pq v1.10.9 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ=
|
||||
github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA=
|
||||
github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA=
|
||||
@@ -5,9 +7,8 @@ github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFos
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8 h1:FfZ3gj38NjllZIeJAmMhr+qKL8Wu+nOoI3GqacKw1NM=
|
||||
github.com/gabriel-vasile/mimetype v1.4.8/go.mod h1:ByKUIKGjh1ODkGM1asKUbQZOLGrPjydw3hYPU2YU9t8=
|
||||
github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY=
|
||||
github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok=
|
||||
github.com/gin-contrib/cors v1.7.6 h1:3gQ8GMzs1Ylpf70y8bMw4fVpycXIeX1ZemuSQIsnQQY=
|
||||
@@ -16,18 +17,22 @@ github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w
|
||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||
github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
|
||||
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
|
||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
|
||||
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
|
||||
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
|
||||
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
|
||||
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
|
||||
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
|
||||
github.com/go-playground/validator/v10 v10.27.0 h1:w8+XrWVMhGkxOaaowyKH35gFydVHOvC0/uWoy2Fzwn4=
|
||||
github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAumgzX4CA1XNl4ZmDuVHKKo=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
|
||||
github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4=
|
||||
github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
|
||||
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
|
||||
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
|
||||
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
|
||||
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@@ -45,7 +50,6 @@ github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -53,6 +57,7 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G
|
||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
|
||||
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
@@ -65,6 +70,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
|
||||
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=
|
||||
@@ -79,6 +86,8 @@ golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
||||
golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs=
|
||||
golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8=
|
||||
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
|
||||
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
|
||||
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
|
||||
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
@@ -92,4 +101,5 @@ google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7I
|
||||
google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
|
||||
63
backend/internal/auth/jwt.go
Normal file
63
backend/internal/auth/jwt.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UserClaims defines the custom claims structure
|
||||
// Senior Tip: Embed information that helps you avoid DB lookups later.
|
||||
type UserClaims struct {
|
||||
Name string `json:"user_name"`
|
||||
Email string `json:"user_email"`
|
||||
AvatarURL *string `json:"avatar_url"` // Nullable avatar URL to avoid DB queries
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
// GenerateJWT creates a stateless JWT token for a user
|
||||
// Changed: Input is now userID (and optional role), not sessionID
|
||||
func GenerateJWT(userID uuid.UUID, name string, email string, avatarURL *string, secret string, expiresIn time.Duration) (string, error) {
|
||||
claims := UserClaims{
|
||||
Name: name,
|
||||
Email: email,
|
||||
AvatarURL: avatarURL,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
// Standard claim "Subject" is technically where UserID belongs,
|
||||
// but having a typed UserID field is easier for Go type assertions.
|
||||
Subject: userID.String(),
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(expiresIn)),
|
||||
IssuedAt: jwt.NewNumericDate(time.Now()),
|
||||
Issuer: "realtime-collab", // Your app name
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString([]byte(secret))
|
||||
}
|
||||
|
||||
// ValidateJWT parses the token and extracts the UserClaims
|
||||
// Changed: Returns *UserClaims so you can access UserID and Role directly
|
||||
func ValidateJWT(tokenString, secret string) (*UserClaims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// Security Check: Always validate the signing algorithm
|
||||
// to prevent "None" algorithm attacks.
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, errors.New("invalid signing method")
|
||||
}
|
||||
return []byte(secret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Type assertion to get our custom struct back
|
||||
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("invalid token claims")
|
||||
}
|
||||
193
backend/internal/auth/middleware.go
Normal file
193
backend/internal/auth/middleware.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const UserContextKey contextKey = "user"
|
||||
const ContextUserIDKey = "user_id"
|
||||
|
||||
// AuthMiddleware provides auth middleware
|
||||
type AuthMiddleware struct {
|
||||
store store.Store
|
||||
jwtSecret string
|
||||
}
|
||||
|
||||
// NewAuthMiddleware creates a new auth middleware
|
||||
func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
store: store,
|
||||
jwtSecret: jwtSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAuth middleware requires valid authentication
|
||||
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
fmt.Println("🔒 RequireAuth: Starting authentication check")
|
||||
|
||||
user, claims, err := m.getUserFromToken(c)
|
||||
|
||||
fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
|
||||
if claims != nil {
|
||||
fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
|
||||
}
|
||||
|
||||
if err != nil || user == nil {
|
||||
fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user)
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Note: Name and Email might be empty for old JWT tokens
|
||||
if claims.Name == "" || claims.Email == "" {
|
||||
fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
|
||||
}
|
||||
|
||||
fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
|
||||
c.Set(ContextUserIDKey, user)
|
||||
c.Set("user_email", claims.Email)
|
||||
c.Set("user_name", claims.Name)
|
||||
if claims.AvatarURL != nil {
|
||||
c.Set("avatar_url", *claims.AvatarURL)
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// OptionalAuth middleware sets user if authenticated, but doesn't require it
|
||||
func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
user, claims, _ := m.getUserFromToken(c)
|
||||
if user != nil {
|
||||
c.Set(string(UserContextKey), user)
|
||||
c.Set(ContextUserIDKey, user)
|
||||
if claims != nil {
|
||||
c.Set("user_email", claims.Email)
|
||||
c.Set("user_name", claims.Name)
|
||||
if claims.AvatarURL != nil {
|
||||
c.Set("avatar_url", *claims.AvatarURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// getUserFromToken parses the JWT and returns the UserID and the full Claims (for name/email)
|
||||
// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error)
|
||||
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)
|
||||
|
||||
if authHeader == "" {
|
||||
fmt.Println("⚠️ getUserFromToken: No Authorization header")
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
parts := strings.Split(authHeader, " ")
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
tokenString := parts[1]
|
||||
fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])
|
||||
|
||||
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
// 必须要验证签名算法是 HMAC (HS256)
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(m.jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 2. 验证 Token 有效性并提取 Claims
|
||||
if claims, ok := token.Claims.(*UserClaims); ok && token.Valid {
|
||||
// 3. 把 String 类型的 Subject 转回 UUID
|
||||
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
|
||||
userID, err := uuid.Parse(claims.Subject)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
|
||||
return nil, nil, fmt.Errorf("invalid user ID in token")
|
||||
}
|
||||
|
||||
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
|
||||
// 这一步完全没有查数据库,速度极快
|
||||
fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
|
||||
return &userID, claims, nil
|
||||
}
|
||||
|
||||
fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
|
||||
return nil, nil, fmt.Errorf("invalid token claims")
|
||||
}
|
||||
|
||||
// GetUserFromContext extracts user ID from context
|
||||
func GetUserFromContext(c *gin.Context) *uuid.UUID {
|
||||
// 修正点:使用和存入时完全一样的 Key
|
||||
val, exists := c.Get(ContextUserIDKey)
|
||||
fmt.Println("within getFromContext the id is ... ")
|
||||
fmt.Println(val);
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 修正点:断言为 *uuid.UUID (因为我们在中间件里存的就是这个类型)
|
||||
uid, ok := val.(*uuid.UUID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
return uid
|
||||
}
|
||||
|
||||
// ValidateToken validates a JWT token and returns user ID, name, and avatar URL from JWT claims
|
||||
func (m *AuthMiddleware) ValidateToken(tokenString string) (*uuid.UUID, string, string, error) {
|
||||
// Parse and validate JWT
|
||||
claims, err := ValidateJWT(tokenString, m.jwtSecret)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("invalid token: %w", err)
|
||||
}
|
||||
|
||||
// Parse user ID from claims
|
||||
userID, err := uuid.Parse(claims.Subject)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("invalid user ID in token: %w", err)
|
||||
}
|
||||
|
||||
// Get session from database by token (for revocation capability)
|
||||
session, err := m.store.GetSessionByToken(context.Background(), tokenString)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("session not found: %w", err)
|
||||
}
|
||||
|
||||
// Verify session UserID matches JWT Subject
|
||||
if session.UserID != userID {
|
||||
return nil, "", "", fmt.Errorf("session ID mismatch")
|
||||
}
|
||||
|
||||
// Extract avatar URL from claims (handle nil gracefully)
|
||||
avatarURL := ""
|
||||
if claims.AvatarURL != nil {
|
||||
avatarURL = *claims.AvatarURL
|
||||
}
|
||||
|
||||
// Return user data from JWT claims - no DB query needed!
|
||||
return &userID, claims.Name, avatarURL, nil
|
||||
}
|
||||
32
backend/internal/auth/oauth.go
Normal file
32
backend/internal/auth/oauth.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/github"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
// GetGoogleOAuthConfig returns Google OAuth2 config
|
||||
func GetGoogleOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
},
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
}
|
||||
|
||||
// GetGitHubOAuthConfig returns GitHub OAuth2 config
|
||||
func GetGitHubOAuthConfig(clientID, clientSecret, redirectURL string) *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirectURL,
|
||||
Scopes: []string{"user:email"},
|
||||
Endpoint: github.Endpoint,
|
||||
}
|
||||
}
|
||||
302
backend/internal/handlers/auth.go
Normal file
302
backend/internal/handlers/auth.go
Normal file
@@ -0,0 +1,302 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type AuthHandler struct {
|
||||
store store.Store
|
||||
googleConfig *oauth2.Config
|
||||
githubConfig *oauth2.Config
|
||||
jwtSecret string
|
||||
frontendURL string
|
||||
}
|
||||
|
||||
func NewAuthHandler(store store.Store, jwtSecret, frontendURL string) *AuthHandler {
|
||||
googleConfig := auth.GetGoogleOAuthConfig(
|
||||
os.Getenv("GOOGLE_CLIENT_ID"),
|
||||
os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||
os.Getenv("GOOGLE_REDIRECT_URL"),
|
||||
)
|
||||
|
||||
githubConfig := auth.GetGitHubOAuthConfig(
|
||||
os.Getenv("GITHUB_CLIENT_ID"),
|
||||
os.Getenv("GITHUB_CLIENT_SECRET"),
|
||||
os.Getenv("GITHUB_REDIRECT_URL"),
|
||||
)
|
||||
|
||||
return &AuthHandler{
|
||||
store: store,
|
||||
googleConfig: googleConfig,
|
||||
githubConfig: githubConfig,
|
||||
jwtSecret: jwtSecret,
|
||||
frontendURL: frontendURL,
|
||||
}
|
||||
}
|
||||
|
||||
// GoogleLogin redirects to Google OAuth
|
||||
func (h *AuthHandler) GoogleLogin(c *gin.Context) {
|
||||
// Generate random state and set cookie
|
||||
oauthState := generateStateOauthCookie(c.Writer)
|
||||
url := h.googleConfig.AuthCodeURL(oauthState, oauth2.AccessTypeOffline)
|
||||
c.Redirect(http.StatusTemporaryRedirect, url)
|
||||
}
|
||||
|
||||
// GoogleCallback handles Google OAuth callback
|
||||
func (h *AuthHandler) GoogleCallback(c *gin.Context) {
|
||||
oauthState, err := c.Cookie("oauthstate")
|
||||
if err != nil || c.Query("state") != oauthState {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
|
||||
return
|
||||
}
|
||||
log.Println("Google callback state:", c.Query("state"))
|
||||
// Exchange code for token
|
||||
token, err := h.googleConfig.Exchange(context.Background(), c.Query("code"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user info from Google
|
||||
client := h.googleConfig.Client(context.Background(), token)
|
||||
resp, err := client.Get("https://www.googleapis.com/oauth2/v2/userinfo")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
|
||||
return
|
||||
}
|
||||
log.Println("Google user info response status:", resp.Status)
|
||||
log.Println("Google user info response headers:", resp.Header)
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, _ := io.ReadAll(resp.Body)
|
||||
var userInfo struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Picture string `json:"picture"`
|
||||
}
|
||||
json.Unmarshal(data, &userInfo)
|
||||
log.Println("Google user info:", userInfo)
|
||||
// Upsert user in database
|
||||
user, err := h.store.UpsertUser(
|
||||
c.Request.Context(),
|
||||
"google",
|
||||
userInfo.ID,
|
||||
userInfo.Email,
|
||||
userInfo.Name,
|
||||
&userInfo.Picture,
|
||||
)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
||||
return
|
||||
}
|
||||
|
||||
// Create session and JWT
|
||||
jwt, err := h.createSessionAndJWT(c, user)
|
||||
if err != nil {
|
||||
fmt.Printf("❌ DATABASE ERROR: %v\n", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": fmt.Sprintf("CreateSession Error: %v", err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to frontend with token
|
||||
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s", h.frontendURL, jwt)
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
}
|
||||
|
||||
// GithubLogin redirects to GitHub OAuth
|
||||
func (h *AuthHandler) GithubLogin(c *gin.Context) {
|
||||
url := h.githubConfig.AuthCodeURL("state", oauth2.AccessTypeOffline)
|
||||
c.Redirect(http.StatusTemporaryRedirect, url)
|
||||
}
|
||||
|
||||
// GithubCallback handles GitHub OAuth callback
|
||||
func (h *AuthHandler) GithubCallback(c *gin.Context) {
|
||||
code := c.Query("code")
|
||||
if code == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "No code provided"})
|
||||
return
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
token, err := h.githubConfig.Exchange(context.Background(), code)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to exchange token"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user info from GitHub
|
||||
client := h.githubConfig.Client(context.Background(), token)
|
||||
|
||||
// Get user profile
|
||||
resp, err := client.Get("https://api.github.com/user")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
data, _ := io.ReadAll(resp.Body)
|
||||
var userInfo struct {
|
||||
ID int `json:"id"`
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
}
|
||||
json.Unmarshal(data, &userInfo)
|
||||
|
||||
// If email is not public, fetch it separately
|
||||
if userInfo.Email == "" {
|
||||
emailResp, _ := client.Get("https://api.github.com/user/emails")
|
||||
if emailResp != nil {
|
||||
defer emailResp.Body.Close()
|
||||
emailData, _ := io.ReadAll(emailResp.Body)
|
||||
var emails []struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
}
|
||||
json.Unmarshal(emailData, &emails)
|
||||
for _, e := range emails {
|
||||
if e.Primary {
|
||||
userInfo.Email = e.Email
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Use login as name if name is empty
|
||||
if userInfo.Name == "" {
|
||||
userInfo.Name = userInfo.Login
|
||||
}
|
||||
|
||||
// Upsert user in database
|
||||
user, err := h.store.UpsertUser(
|
||||
c.Request.Context(),
|
||||
"github",
|
||||
fmt.Sprintf("%d", userInfo.ID),
|
||||
userInfo.Email,
|
||||
userInfo.Name,
|
||||
&userInfo.AvatarURL,
|
||||
)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create user"})
|
||||
return
|
||||
}
|
||||
|
||||
// Create session and JWT
|
||||
jwt, err := h.createSessionAndJWT(c, user)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"})
|
||||
return
|
||||
}
|
||||
|
||||
// Redirect to frontend with token
|
||||
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s", h.frontendURL, jwt)
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
}
|
||||
|
||||
// Me returns current user info
|
||||
func (h *AuthHandler) Me(c *gin.Context) {
|
||||
userID := auth.GetUserFromContext(c)
|
||||
if userID == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.store.GetUserByID(c.Request.Context(), *userID)
|
||||
if err != nil || user == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.UserResponse{User: user})
|
||||
}
|
||||
|
||||
// Logout invalidates the session
|
||||
func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Already logged out"})
|
||||
return
|
||||
}
|
||||
|
||||
// Extract token
|
||||
token := ""
|
||||
if len(authHeader) > 7 && authHeader[:7] == "Bearer " {
|
||||
token = authHeader[7:]
|
||||
}
|
||||
|
||||
if token != "" {
|
||||
h.store.DeleteSession(c.Request.Context(), token)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
|
||||
}
|
||||
|
||||
// Helper: create session and JWT
|
||||
func (h *AuthHandler) createSessionAndJWT(c *gin.Context, user *models.User) (string, error) {
|
||||
expiresAt := time.Now().Add(7 * 24 * time.Hour) // 7 days
|
||||
|
||||
// Generate JWT first (we need it for session) - now includes avatar URL
|
||||
jwt, err := auth.GenerateJWT(user.ID, user.Name, user.Email, user.AvatarURL, h.jwtSecret, 7*24*time.Hour)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Create session in database
|
||||
sessionID := uuid.New()
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
ipAddress := c.ClientIP()
|
||||
_, err = h.store.CreateSession(
|
||||
c.Request.Context(),
|
||||
user.ID,
|
||||
sessionID,
|
||||
jwt,
|
||||
expiresAt,
|
||||
&userAgent,
|
||||
&ipAddress,
|
||||
)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return jwt, nil
|
||||
}
|
||||
func generateStateOauthCookie(w http.ResponseWriter) string {
|
||||
b := make([]byte, 16)
|
||||
rand.Read(b)
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
cookie := http.Cookie{
|
||||
Name: "oauthstate",
|
||||
Value: state,
|
||||
Expires: time.Now().Add(10 * time.Minute),
|
||||
HttpOnly: true, // Prevents JavaScript access (XSS protection)
|
||||
Secure: false, // Must be false for http://localhost (set true in production)
|
||||
SameSite: http.SameSiteLaxMode, // ✅ Allows same-site OAuth redirects
|
||||
Path: "/", // ✅ Ensures cookie is sent to all backend paths
|
||||
}
|
||||
http.SetCookie(w, &cookie)
|
||||
|
||||
return state
|
||||
}
|
||||
@@ -1,8 +1,10 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -10,135 +12,199 @@ import (
|
||||
)
|
||||
|
||||
type DocumentHandler struct {
|
||||
store *store.Store
|
||||
store *store.PostgresStore
|
||||
}
|
||||
|
||||
func NewDocumentHandler(s *store.Store) *DocumentHandler {
|
||||
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
|
||||
return &DocumentHandler{store: s}
|
||||
}
|
||||
|
||||
// CreateDocument creates a new document
|
||||
func (h *DocumentHandler) CreateDocument(c *gin.Context) {
|
||||
var req models.CreateDocumentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate document type
|
||||
if req.Type != models.DocumentTypeEditor && req.Type != models.DocumentTypeKanban {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document type"})
|
||||
return
|
||||
}
|
||||
// CreateDocument creates a new document (requires auth)
|
||||
func (h *DocumentHandler) CreateDocument(c *gin.Context) {
|
||||
fmt.Println("getting userId right now.... ")
|
||||
userID := auth.GetUserFromContext(c)
|
||||
fmt.Println(userID)
|
||||
if userID == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
doc, err := h.store.CreateDocument(req.Name, req.Type)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
var req models.CreateDocumentRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, doc)
|
||||
}
|
||||
// Create document with owner_id
|
||||
doc, err := h.store.CreateDocumentWithOwner(req.Name, req.Type, userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("Failed to create document: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, doc)
|
||||
}
|
||||
|
||||
// ListDocuments returns all documents
|
||||
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
|
||||
documents, err := h.store.ListDocuments()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
userID := auth.GetUserFromContext(c)
|
||||
|
||||
if documents == nil {
|
||||
documents = []models.Document{}
|
||||
}
|
||||
var docs []models.Document
|
||||
var err error
|
||||
|
||||
if userID != nil {
|
||||
// Authenticated: show owned + shared documents
|
||||
docs, err = h.store.ListUserDocuments(c.Request.Context(), *userID)
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("we dont know you: %v", err)})
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list documents"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.DocumentListResponse{
|
||||
Documents: docs,
|
||||
Total: len(docs),
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.DocumentListResponse{
|
||||
Documents: documents,
|
||||
Total: len(documents),
|
||||
})
|
||||
}
|
||||
|
||||
// GetDocument returns a single document
|
||||
func (h *DocumentHandler) GetDocument(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"})
|
||||
return
|
||||
}
|
||||
id, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
doc, err := h.store.GetDocument(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "document not found"})
|
||||
return
|
||||
}
|
||||
userID := auth.GetUserFromContext(c)
|
||||
|
||||
c.JSON(http.StatusOK, doc)
|
||||
}
|
||||
// Check permission if authenticated
|
||||
if userID != nil {
|
||||
canView, err := h.store.CanViewDocument(c.Request.Context(), id, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"})
|
||||
return
|
||||
}
|
||||
if !canView {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Access denied"})
|
||||
return
|
||||
}
|
||||
}else{
|
||||
c.JSON("this file is not public")
|
||||
return
|
||||
}
|
||||
|
||||
doc, err := h.store.GetDocument(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Document not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, doc)
|
||||
}
|
||||
// GetDocumentState returns the Yjs state for a document
|
||||
func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"})
|
||||
return
|
||||
}
|
||||
// GetDocumentState retrieves document state (requires view permission)
|
||||
func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
|
||||
id, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
doc, err := h.store.GetDocument(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "document not found"})
|
||||
return
|
||||
}
|
||||
userID := auth.GetUserFromContext(c)
|
||||
|
||||
// Return binary state
|
||||
if doc.YjsState == nil {
|
||||
c.Data(http.StatusOK, "application/octet-stream", []byte{})
|
||||
return
|
||||
}
|
||||
// Check permission if authenticated
|
||||
if userID != nil {
|
||||
canView, err := h.store.CanViewDocument(c.Request.Context(), id, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"})
|
||||
return
|
||||
}
|
||||
if !canView {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Access denied"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Data(http.StatusOK, "application/octet-stream", doc.YjsState)
|
||||
}
|
||||
doc, err := h.store.GetDocument(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "Document not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateDocumentState updates the Yjs state for a document
|
||||
func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"})
|
||||
return
|
||||
}
|
||||
c.Data(http.StatusOK, "application/octet-stream", doc.YjsState)
|
||||
}
|
||||
|
||||
// Read binary body
|
||||
state, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
|
||||
return
|
||||
}
|
||||
// UpdateDocumentState updates document state (requires edit permission)
|
||||
func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
|
||||
id, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.store.UpdateDocumentState(id, state)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
userID := auth.GetUserFromContext(c)
|
||||
if userID == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "state updated successfully"})
|
||||
}
|
||||
// Check edit permission
|
||||
canEdit, err := h.store.CanEditDocument(c.Request.Context(), id, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"})
|
||||
return
|
||||
}
|
||||
if !canEdit {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Edit access denied"})
|
||||
return
|
||||
}
|
||||
|
||||
// DeleteDocument deletes a document
|
||||
func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
|
||||
idStr := c.Param("id")
|
||||
id, err := uuid.Parse(idStr)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid document ID"})
|
||||
return
|
||||
}
|
||||
var req models.UpdateStateRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.store.DeleteDocument(id)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "document not found"})
|
||||
return
|
||||
}
|
||||
if err := h.store.UpdateDocumentState(id, req.State); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to update state"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "document deleted successfully"})
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
|
||||
}
|
||||
|
||||
// DeleteDocument deletes a document (owner only)
|
||||
func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
|
||||
id, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := auth.GetUserFromContext(c)
|
||||
if userID == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check ownership
|
||||
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), id, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
|
||||
return
|
||||
}
|
||||
if !isOwner {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can delete documents"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.store.DeleteDocument(id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete document"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"})
|
||||
}
|
||||
286
backend/internal/handlers/share.go
Normal file
286
backend/internal/handlers/share.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os" // Add this
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type ShareHandler struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func NewShareHandler(store store.Store) *ShareHandler {
|
||||
return &ShareHandler{store: store}
|
||||
}
|
||||
|
||||
// CreateShare creates a new document share
|
||||
func (h *ShareHandler) CreateShare(c *gin.Context) {
|
||||
userID := auth.GetUserFromContext(c)
|
||||
if userID == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
documentID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is owner
|
||||
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
|
||||
return
|
||||
}
|
||||
if !isOwner {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can share documents"})
|
||||
return
|
||||
}
|
||||
|
||||
var req models.CreateShareRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Get user by email
|
||||
targetUser, err := h.store.GetUserByEmail(c.Request.Context(), req.UserEmail)
|
||||
if err != nil || targetUser == nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "User not found"})
|
||||
return
|
||||
}
|
||||
|
||||
// Create share
|
||||
share, err := h.store.CreateDocumentShare(
|
||||
c.Request.Context(),
|
||||
documentID,
|
||||
targetUser.ID,
|
||||
req.Permission,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create share"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusCreated, share)
|
||||
}
|
||||
|
||||
// ListShares lists all shares for a document
|
||||
func (h *ShareHandler) ListShares(c *gin.Context) {
|
||||
userID := auth.GetUserFromContext(c)
|
||||
if userID == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
documentID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is owner
|
||||
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
|
||||
return
|
||||
}
|
||||
if !isOwner {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can view shares"})
|
||||
return
|
||||
}
|
||||
|
||||
shares, err := h.store.ListDocumentShares(c.Request.Context(), documentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to list shares"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, models.ShareListResponse{Shares: shares})
|
||||
}
|
||||
|
||||
// DeleteShare removes a share
|
||||
func (h *ShareHandler) DeleteShare(c *gin.Context) {
|
||||
userID := auth.GetUserFromContext(c)
|
||||
if userID == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
documentID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
targetUserID, err := uuid.Parse(c.Param("userId"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid user ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is owner
|
||||
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
|
||||
return
|
||||
}
|
||||
if !isOwner {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Only owner can delete shares"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.store.DeleteDocumentShare(c.Request.Context(), documentID, targetUserID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to delete share"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Share deleted successfully"})
|
||||
}
|
||||
// CreateShareLink generates a public share link
|
||||
func (h *ShareHandler) CreateShareLink(c *gin.Context) {
|
||||
documentID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := auth.GetUserFromContext(c)
|
||||
if userID == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is owner
|
||||
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
|
||||
return
|
||||
}
|
||||
if !isOwner {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Only document owner can create share links"})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse request body
|
||||
var req struct {
|
||||
Permission string `json:"permission" binding:"required,oneof=view edit"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Permission must be 'view' or 'edit'"})
|
||||
return
|
||||
}
|
||||
|
||||
// Generate share token
|
||||
token, err := h.store.GenerateShareToken(c.Request.Context(), documentID, req.Permission)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate share link"})
|
||||
return
|
||||
}
|
||||
|
||||
// Get frontend URL from env
|
||||
frontendURL := os.Getenv("FRONTEND_URL")
|
||||
if frontendURL == "" {
|
||||
frontendURL = "http://localhost:5173"
|
||||
}
|
||||
|
||||
shareURL := fmt.Sprintf("%s/editor/%s?share=%s", frontendURL, documentID.String(), token)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"url": shareURL,
|
||||
"token": token,
|
||||
"permission": req.Permission,
|
||||
})
|
||||
}
|
||||
|
||||
// GetShareLink retrieves the current public share link
|
||||
func (h *ShareHandler) GetShareLink(c *gin.Context) {
|
||||
documentID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := auth.GetUserFromContext(c)
|
||||
if userID == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is owner
|
||||
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
|
||||
return
|
||||
}
|
||||
if !isOwner {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Only document owner can view share links"})
|
||||
return
|
||||
}
|
||||
|
||||
token, exists, err := h.store.GetShareToken(c.Request.Context(), documentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get share link"})
|
||||
return
|
||||
}
|
||||
if !exists {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "No public share link exists"})
|
||||
return
|
||||
}
|
||||
|
||||
frontendURL := os.Getenv("FRONTEND_URL")
|
||||
if frontendURL == "" {
|
||||
frontendURL = "http://localhost:5173"
|
||||
}
|
||||
|
||||
shareURL := fmt.Sprintf("%s/editor/%s?share=%s", frontendURL, documentID.String(), token)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"url": shareURL,
|
||||
"token": token,
|
||||
})
|
||||
}
|
||||
|
||||
// RevokeShareLink removes the public share link
|
||||
func (h *ShareHandler) RevokeShareLink(c *gin.Context) {
|
||||
documentID, err := uuid.Parse(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
userID := auth.GetUserFromContext(c)
|
||||
if userID == nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if user is owner
|
||||
isOwner, err := h.store.IsDocumentOwner(c.Request.Context(), documentID, *userID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check ownership"})
|
||||
return
|
||||
}
|
||||
if !isOwner {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Only document owner can revoke share links"})
|
||||
return
|
||||
}
|
||||
|
||||
err = h.store.RevokeShareToken(c.Request.Context(), documentID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to revoke share link"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Share link revoked"})
|
||||
}
|
||||
@@ -3,57 +3,147 @@ package handlers
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/hub"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
// Allow all origins for development
|
||||
// TODO: Restrict in production
|
||||
return true
|
||||
},
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
// Check origin against allowed origins from environment
|
||||
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
|
||||
if allowedOrigins == "" {
|
||||
// Default for development
|
||||
origin := r.Header.Get("Origin")
|
||||
return origin == "http://localhost:5173" || origin == "http://localhost:3000"
|
||||
}
|
||||
// Production: validate against ALLOWED_ORIGINS
|
||||
// TODO: Parse and validate origin
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
type WebSocketHandler struct {
|
||||
hub *hub.Hub
|
||||
}
|
||||
type WebSocketHandler struct {
|
||||
hub *hub.Hub
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func NewWebSocketHandler(h *hub.Hub) *WebSocketHandler {
|
||||
return &WebSocketHandler{hub: h}
|
||||
}
|
||||
func NewWebSocketHandler(h *hub.Hub, s store.Store) *WebSocketHandler {
|
||||
return &WebSocketHandler{
|
||||
hub: h,
|
||||
store: s,
|
||||
}
|
||||
}
|
||||
|
||||
func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context){
|
||||
func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
||||
roomID := c.Param("roomId")
|
||||
|
||||
if(roomID == ""){
|
||||
if roomID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "roomId is required"})
|
||||
return
|
||||
}
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
|
||||
// Parse document ID
|
||||
documentID, err := uuid.Parse(roomID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upgrade to WebSocket"})
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid document ID"})
|
||||
return
|
||||
}
|
||||
|
||||
// Create a new client
|
||||
clientID := uuid.New().String()
|
||||
client := hub.NewClient(clientID, conn, wsh.hub, roomID)
|
||||
// Try to authenticate via JWT token or share token
|
||||
var userID *uuid.UUID
|
||||
var userName string
|
||||
var userAvatar *string
|
||||
authenticated := false
|
||||
|
||||
// Register client with hub
|
||||
wsh.hub.Register <- client
|
||||
// Check for JWT token in query parameter
|
||||
jwtToken := c.Query("token")
|
||||
if jwtToken != "" {
|
||||
// Validate JWT and get user data from token claims (no DB query!)
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
log.Println("JWT_SECRET not configured")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Server configuration error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Start read and write pumps in separate goroutines
|
||||
go client.WritePump()
|
||||
go client.ReadPump()
|
||||
authMiddleware := auth.NewAuthMiddleware(wsh.store, jwtSecret)
|
||||
uid, name, avatar, err := authMiddleware.ValidateToken(jwtToken)
|
||||
if err == nil && uid != nil {
|
||||
// User data comes directly from JWT claims - no DB query needed!
|
||||
userID = uid
|
||||
userName = name
|
||||
if avatar != "" {
|
||||
userAvatar = &avatar
|
||||
}
|
||||
authenticated = true
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("WebSocket connection established for client %s in room %s", clientID, roomID)
|
||||
}
|
||||
// If not authenticated via JWT, check for share token
|
||||
if !authenticated {
|
||||
shareToken := c.Query("share")
|
||||
if shareToken != "" {
|
||||
// Validate share token
|
||||
valid, err := wsh.store.ValidateShareToken(c.Request.Context(), documentID, shareToken)
|
||||
if err != nil {
|
||||
log.Printf("Error validating share token: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to validate share token"})
|
||||
return
|
||||
}
|
||||
if !valid {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "Invalid or expired share token"})
|
||||
return
|
||||
}
|
||||
// Share token is valid, allow connection with anonymous user
|
||||
userName = "Anonymous"
|
||||
authenticated = true
|
||||
}
|
||||
}
|
||||
|
||||
// If still not authenticated, reject connection
|
||||
if !authenticated {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required. Provide 'token' or 'share' query parameter"})
|
||||
return
|
||||
}
|
||||
|
||||
// If authenticated with JWT, check document permissions
|
||||
if userID != nil {
|
||||
canView, err := wsh.store.CanViewDocument(c.Request.Context(), documentID, *userID)
|
||||
if err != nil {
|
||||
log.Printf("Error checking permissions: %v", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"})
|
||||
return
|
||||
}
|
||||
if !canView {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "You don't have permission to access this document"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Upgrade connection
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to upgrade connection: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Create client with user information
|
||||
clientID := uuid.New().String()
|
||||
client := hub.NewClient(clientID, userID, userName, userAvatar, conn, wsh.hub, roomID)
|
||||
|
||||
// Register client
|
||||
wsh.hub.Register <- client
|
||||
|
||||
// Start goroutines
|
||||
go client.WritePump()
|
||||
go client.ReadPump()
|
||||
|
||||
log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ package hub
|
||||
import (
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
@@ -14,11 +16,20 @@ type Message struct {
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
ID string
|
||||
Conn *websocket.Conn
|
||||
send chan []byte
|
||||
hub *Hub
|
||||
roomID string
|
||||
ID string
|
||||
UserID *uuid.UUID // Authenticated user ID (nil for public share access)
|
||||
UserName string // User's display name for presence
|
||||
UserAvatar *string // User's avatar URL for presence
|
||||
Conn *websocket.Conn
|
||||
send chan []byte
|
||||
sendMu sync.Mutex
|
||||
sendClosed bool
|
||||
hub *Hub
|
||||
roomID string
|
||||
mutex sync.Mutex
|
||||
unregisterOnce sync.Once
|
||||
failureCount int
|
||||
failureMu sync.Mutex
|
||||
}
|
||||
type Room struct {
|
||||
ID string
|
||||
@@ -74,54 +85,99 @@ func (h *Hub) registerClient(client *Client) {
|
||||
log.Printf("Client %s joined room %s (total clients: %d)", client.ID, client.roomID, len(room.clients))
|
||||
}
|
||||
func (h *Hub) unregisterClient(client *Client) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
room, exists := h.rooms[client.roomID]
|
||||
if !exists {
|
||||
log.Printf("Room %s does not exist for client %s", client.roomID, client.ID)
|
||||
return
|
||||
}
|
||||
room.mu.Lock()
|
||||
if _, ok := room.clients[client]; ok {
|
||||
delete(room.clients, client)
|
||||
close(client.send)
|
||||
log.Printf("Client %s disconnected from room %s", client.ID, client.roomID)
|
||||
}
|
||||
room, exists := h.rooms[client.roomID]
|
||||
if !exists {
|
||||
log.Printf("Room %s does not exist for client %s", client.roomID, client.ID)
|
||||
return
|
||||
}
|
||||
|
||||
room.mu.Unlock()
|
||||
log.Printf("Client %s left room %s (total clients: %d)", client.ID, client.roomID, len(room.clients))
|
||||
room.mu.Lock()
|
||||
defer room.mu.Unlock()
|
||||
|
||||
if len(room.clients) == 0 {
|
||||
delete(h.rooms, client.roomID)
|
||||
log.Printf("Deleted empty room with ID: %s", client.roomID)
|
||||
}
|
||||
if _, ok := room.clients[client]; ok {
|
||||
delete(room.clients, client)
|
||||
|
||||
// Safely close send channel exactly once
|
||||
client.sendMu.Lock()
|
||||
if !client.sendClosed {
|
||||
close(client.send)
|
||||
client.sendClosed = true
|
||||
}
|
||||
client.sendMu.Unlock()
|
||||
|
||||
log.Printf("Client %s disconnected from room %s (total clients: %d)",
|
||||
client.ID, client.roomID, len(room.clients))
|
||||
}
|
||||
|
||||
if len(room.clients) == 0 {
|
||||
delete(h.rooms, client.roomID)
|
||||
log.Printf("Deleted empty room with ID: %s", client.roomID)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
writeWait = 10 * time.Second
|
||||
pongWait = 60 * time.Second
|
||||
pingPeriod = (pongWait * 9) / 10 // 54 seconds
|
||||
maxSendFailures = 5
|
||||
)
|
||||
|
||||
func (h *Hub) broadcastMessage(message *Message) {
|
||||
h.mu.RLock()
|
||||
room, exists := h.rooms[message.RoomID]
|
||||
h.mu.RUnlock()
|
||||
if !exists {
|
||||
log.Printf("Room %s does not exist for broadcasting", message.RoomID)
|
||||
return
|
||||
}
|
||||
h.mu.RLock()
|
||||
room, exists := h.rooms[message.RoomID]
|
||||
h.mu.RUnlock()
|
||||
if !exists {
|
||||
log.Printf("Room %s does not exist for broadcasting", message.RoomID)
|
||||
return
|
||||
}
|
||||
|
||||
room.mu.RLock()
|
||||
defer room.mu.RUnlock()
|
||||
for client := range room.clients {
|
||||
if client != message.sender {
|
||||
select {
|
||||
case client.send <- message.Data:
|
||||
default:
|
||||
log.Printf("Failed to send to client %s (channel full)", client.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
room.mu.RLock()
|
||||
defer room.mu.RUnlock()
|
||||
|
||||
for client := range room.clients {
|
||||
if client != message.sender {
|
||||
select {
|
||||
case client.send <- message.Data:
|
||||
// Success - reset failure count
|
||||
client.failureMu.Lock()
|
||||
client.failureCount = 0
|
||||
client.failureMu.Unlock()
|
||||
|
||||
default:
|
||||
// Failed - increment failure count
|
||||
client.failureMu.Lock()
|
||||
client.failureCount++
|
||||
currentFailures := client.failureCount
|
||||
client.failureMu.Unlock()
|
||||
|
||||
log.Printf("Failed to send to client %s (channel full, failures: %d/%d)",
|
||||
client.ID, currentFailures, maxSendFailures)
|
||||
|
||||
// Disconnect if threshold exceeded
|
||||
if currentFailures >= maxSendFailures {
|
||||
log.Printf("Client %s exceeded max send failures, disconnecting", client.ID)
|
||||
go func(c *Client) {
|
||||
c.unregister()
|
||||
c.Conn.Close()
|
||||
}(client)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func (c *Client) ReadPump() {
|
||||
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
c.Conn.SetPongHandler(func(string) error {
|
||||
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
defer func() {
|
||||
c.hub.Unregister <- c
|
||||
c.unregister()
|
||||
c.Conn.Close()
|
||||
}()
|
||||
for {
|
||||
@@ -141,24 +197,54 @@ func (c *Client) ReadPump() {
|
||||
}
|
||||
|
||||
func (c *Client) WritePump() {
|
||||
defer func() {
|
||||
c.Conn.Close()
|
||||
}()
|
||||
for message := range c.send {
|
||||
err := c.Conn.WriteMessage(websocket.BinaryMessage, message)
|
||||
if err != nil {
|
||||
log.Printf("Error writing message to client %s: %v", c.ID, err)
|
||||
break
|
||||
}
|
||||
}
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.unregister() // NEW: Now WritePump also unregisters
|
||||
c.Conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.send:
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if !ok {
|
||||
// Hub closed the channel
|
||||
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
err := c.Conn.WriteMessage(websocket.BinaryMessage, message)
|
||||
if err != nil {
|
||||
log.Printf("Error writing message to client %s: %v", c.ID, err)
|
||||
return
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
log.Printf("Ping failed for client %s: %v", c.ID, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewClient(id string, conn *websocket.Conn, hub *Hub, roomID string) *Client {
|
||||
|
||||
func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string, conn *websocket.Conn, hub *Hub, roomID string) *Client {
|
||||
return &Client{
|
||||
ID: id,
|
||||
Conn: conn,
|
||||
send: make(chan []byte, 256),
|
||||
hub: hub,
|
||||
roomID: roomID,
|
||||
ID: id,
|
||||
UserID: userID,
|
||||
UserName: userName,
|
||||
UserAvatar: userAvatar,
|
||||
Conn: conn,
|
||||
send: make(chan []byte, 256),
|
||||
hub: hub,
|
||||
roomID: roomID,
|
||||
}
|
||||
}
|
||||
func (c *Client) unregister() {
|
||||
c.unregisterOnce.Do(func() {
|
||||
c.hub.Unregister <- c
|
||||
})
|
||||
}
|
||||
@@ -14,14 +14,17 @@ const (
|
||||
)
|
||||
|
||||
type Document struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type DocumentType `json:"type"`
|
||||
YjsState []byte `json:"-"` // Don't expose binary data in JSON
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
ID uuid.UUID `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type DocumentType `json:"type"`
|
||||
YjsState []byte `json:"-"`
|
||||
OwnerID *uuid.UUID `json:"owner_id"` // NEW
|
||||
Is_Public bool `json:"is_public"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
|
||||
type CreateDocumentRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Type DocumentType `json:"type" binding:"required"`
|
||||
|
||||
30
backend/internal/models/share.go
Normal file
30
backend/internal/models/share.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DocumentShare struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
DocumentID uuid.UUID `json:"document_id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Permission string `json:"permission"` // "view" or "edit"
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
CreatedBy *uuid.UUID `json:"created_by"`
|
||||
}
|
||||
|
||||
type CreateShareRequest struct {
|
||||
UserEmail string `json:"user_email" binding:"required"`
|
||||
Permission string `json:"permission" binding:"required,oneof=view edit"`
|
||||
}
|
||||
|
||||
type ShareListResponse struct {
|
||||
Shares []DocumentShareWithUser `json:"shares"`
|
||||
}
|
||||
|
||||
type DocumentShareWithUser struct {
|
||||
DocumentShare
|
||||
User User `json:"user"`
|
||||
}
|
||||
48
backend/internal/models/user.go
Normal file
48
backend/internal/models/user.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
AvatarURL *string `json:"avatar_url"`
|
||||
Provider string `json:"provider"`
|
||||
ProviderUserID string `json:"-"` // Don't expose
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
LastLoginAt *time.Time `json:"last_login_at"`
|
||||
}
|
||||
|
||||
type Session struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
TokenHash string `json:"-"` // SHA-256 hash of JWT
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UserAgent *string `json:"user_agent"`
|
||||
IPAddress *string `json:"ip_address"`
|
||||
}
|
||||
|
||||
type OAuthToken struct {
|
||||
ID uuid.UUID `json:"id"`
|
||||
UserID uuid.UUID `json:"user_id"`
|
||||
Provider string `json:"provider"`
|
||||
AccessToken string `json:"-"` // Don't expose
|
||||
RefreshToken *string `json:"-"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
Scope *string `json:"scope"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
}
|
||||
|
||||
// Response for /auth/me endpoint
|
||||
type UserResponse struct {
|
||||
User *User `json:"user"`
|
||||
Token string `json:"token,omitempty"`
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
@@ -10,11 +11,49 @@ import (
|
||||
_ "github.com/lib/pq" // PostgreSQL driver
|
||||
)
|
||||
|
||||
type Store struct{
|
||||
db *sql.DB
|
||||
// Store interface defines all database operations
|
||||
type Store interface {
|
||||
// Document operations
|
||||
CreateDocument(name string, docType models.DocumentType) (*models.Document, error)
|
||||
CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) // ADD THIS
|
||||
GetDocument(id uuid.UUID) (*models.Document, error)
|
||||
ListDocuments() ([]models.Document, error)
|
||||
ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) // ADD THIS
|
||||
UpdateDocumentState(id uuid.UUID, state []byte) error
|
||||
DeleteDocument(id uuid.UUID) error
|
||||
|
||||
// User operations
|
||||
UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error)
|
||||
GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error)
|
||||
GetUserByEmail(ctx context.Context, email string) (*models.User, error)
|
||||
|
||||
// Session operations
|
||||
CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error)
|
||||
GetSessionByToken(ctx context.Context, token string) (*models.Session, error)
|
||||
DeleteSession(ctx context.Context, token string) error
|
||||
CleanupExpiredSessions(ctx context.Context) error
|
||||
|
||||
// Share operations
|
||||
CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, error)
|
||||
ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error)
|
||||
DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error
|
||||
CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||
CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||
IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error)
|
||||
GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error)
|
||||
ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error)
|
||||
RevokeShareToken(ctx context.Context, documentID uuid.UUID) error
|
||||
GetShareToken(ctx context.Context, documentID uuid.UUID) (string, bool, error)
|
||||
|
||||
Close() error
|
||||
}
|
||||
|
||||
func NewStore(databaseUrl string) (*Store, error) {
|
||||
|
||||
type PostgresStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewPostgresStore(databaseUrl string) (*PostgresStore, error) {
|
||||
db, error := sql.Open("postgres", databaseUrl)
|
||||
if error != nil {
|
||||
return nil, error
|
||||
@@ -25,14 +64,14 @@ func NewStore(databaseUrl string) (*Store, error) {
|
||||
db.SetMaxOpenConns(25)
|
||||
db.SetMaxIdleConns(5)
|
||||
db.SetConnMaxLifetime(5 * time.Minute)
|
||||
return &Store{db: db}, nil
|
||||
return &PostgresStore{db: db}, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error {
|
||||
func (s *PostgresStore) Close() error {
|
||||
return s.db.Close()
|
||||
}
|
||||
|
||||
func (s *Store) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
|
||||
func (s *PostgresStore) CreateDocument(name string, docType models.DocumentType) (*models.Document, error) {
|
||||
doc := &models.Document{
|
||||
ID: uuid.New(),
|
||||
Name: name,
|
||||
@@ -62,7 +101,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model
|
||||
}
|
||||
|
||||
// GetDocument retrieves a document by ID
|
||||
func (s *Store) GetDocument(id uuid.UUID) (*models.Document, error) {
|
||||
func (s *PostgresStore) GetDocument(id uuid.UUID) (*models.Document, error) {
|
||||
doc := &models.Document{}
|
||||
|
||||
query := `
|
||||
@@ -92,7 +131,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model
|
||||
|
||||
|
||||
// ListDocuments retrieves all documents
|
||||
func (s *Store) ListDocuments() ([]models.Document, error) {
|
||||
func (s *PostgresStore) ListDocuments() ([]models.Document, error) {
|
||||
query := `
|
||||
SELECT id, name, type, created_at, updated_at
|
||||
FROM documents
|
||||
@@ -118,7 +157,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateDocumentState(id uuid.UUID, state []byte) error {
|
||||
func (s *PostgresStore) UpdateDocumentState(id uuid.UUID, state []byte) error {
|
||||
query := `
|
||||
UPDATE documents
|
||||
SET yjs_state = $1, updated_at = $2
|
||||
@@ -142,7 +181,7 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteDocument(id uuid.UUID) error {
|
||||
func (s *PostgresStore) DeleteDocument(id uuid.UUID) error {
|
||||
query := `DELETE FROM documents WHERE id = $1`
|
||||
|
||||
result, err := s.db.Exec(query, id)
|
||||
@@ -162,3 +201,88 @@ func (s *Store) CreateDocument(name string, docType models.DocumentType) (*model
|
||||
return nil
|
||||
}
|
||||
|
||||
// CreateDocumentWithOwner creates a new document with owner
|
||||
func (s *PostgresStore) CreateDocumentWithOwner(name string, docType models.DocumentType, ownerID *uuid.UUID) (*models.Document, error) {
|
||||
// 1. 检查 docType 是否为空,或者是否合法 (防止 check constraint 报错)
|
||||
if docType == "" {
|
||||
docType = models.DocumentTypeEditor // Default to editor instead of invalid "text"
|
||||
}
|
||||
|
||||
// Validate that docType is one of the allowed values
|
||||
if docType != models.DocumentTypeEditor && docType != models.DocumentTypeKanban {
|
||||
return nil, fmt.Errorf("invalid document type: %s (must be 'editor' or 'kanban')", docType)
|
||||
}
|
||||
|
||||
doc := &models.Document{
|
||||
ID: uuid.New(),
|
||||
Name: name,
|
||||
Type: docType,
|
||||
YjsState: []byte{}, // 这里初始化了空字节
|
||||
OwnerID: ownerID,
|
||||
Is_Public: false, // 显式设置默认值
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
// 2. 补全了 yjs_state 和 is_public
|
||||
query := `
|
||||
INSERT INTO documents (id, name, type, owner_id, yjs_state, is_public, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
RETURNING id, name, type, owner_id, yjs_state, is_public, created_at, updated_at
|
||||
`
|
||||
|
||||
// 3. Scan 的时候也要对应加上
|
||||
err := s.db.QueryRow(query,
|
||||
doc.ID,
|
||||
doc.Name,
|
||||
doc.Type,
|
||||
doc.OwnerID,
|
||||
doc.YjsState, // $5
|
||||
doc.Is_Public, // $6
|
||||
doc.CreatedAt,
|
||||
doc.UpdatedAt,
|
||||
).Scan(
|
||||
&doc.ID,
|
||||
&doc.Name,
|
||||
&doc.Type,
|
||||
&doc.OwnerID,
|
||||
&doc.YjsState, // Scan 回来
|
||||
&doc.Is_Public, // Scan 回来
|
||||
&doc.CreatedAt,
|
||||
&doc.UpdatedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create document: %w", err)
|
||||
}
|
||||
return doc, nil
|
||||
}
|
||||
|
||||
// ListUserDocuments lists documents owned by or shared with a user
|
||||
func (s *PostgresStore) ListUserDocuments(ctx context.Context, userID uuid.UUID) ([]models.Document, error) {
|
||||
query := `
|
||||
SELECT DISTINCT d.id, d.name, d.type, d.owner_id, d.created_at, d.updated_at
|
||||
FROM documents d
|
||||
LEFT JOIN document_shares ds ON d.id = ds.document_id
|
||||
WHERE d.owner_id = $1 OR ds.user_id = $1
|
||||
ORDER BY d.created_at DESC
|
||||
`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list user documents: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var documents []models.Document
|
||||
for rows.Next() {
|
||||
var doc models.Document
|
||||
err := rows.Scan(&doc.ID, &doc.Name, &doc.Type, &doc.OwnerID, &doc.CreatedAt, &doc.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan document: %w", err)
|
||||
}
|
||||
documents = append(documents, doc)
|
||||
}
|
||||
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
88
backend/internal/store/session.go
Normal file
88
backend/internal/store/session.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"time"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CreateSession creates a new session
|
||||
func (s *PostgresStore) CreateSession(ctx context.Context, userID uuid.UUID, sessionID uuid.UUID, token string, expiresAt time.Time, userAgent, ipAddress *string) (*models.Session, error) {
|
||||
// Hash the token before storing
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
// 【修改点 1】: 在 SQL 里显式加上 id 字段
|
||||
// 注意:$1 变成了 id,后面的参数序号全部要顺延 (+1)
|
||||
query := `
|
||||
INSERT INTO sessions (id, user_id, token_hash, expires_at, user_agent, ip_address)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)
|
||||
RETURNING id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
|
||||
`
|
||||
|
||||
var session models.Session
|
||||
// 【修改点 2】: 在参数列表的最前面加上 sessionID
|
||||
// 现在的对应关系:
|
||||
// $1 -> sessionID
|
||||
// $2 -> userID
|
||||
// $3 -> tokenHash
|
||||
// ...
|
||||
err := s.db.QueryRowContext(ctx, query,
|
||||
sessionID, // <--- 这里!把它传进去!
|
||||
userID,
|
||||
tokenHash,
|
||||
expiresAt,
|
||||
userAgent,
|
||||
ipAddress,
|
||||
).Scan(
|
||||
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
|
||||
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// GetSessionByToken retrieves session by JWT token
|
||||
func (s *PostgresStore) GetSessionByToken(ctx context.Context, token string) (*models.Session, error) {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
query := `
|
||||
SELECT id, user_id, token_hash, expires_at, created_at, user_agent, ip_address
|
||||
FROM sessions
|
||||
WHERE token_hash = $1 AND expires_at > NOW()
|
||||
`
|
||||
|
||||
var session models.Session
|
||||
err := s.db.QueryRowContext(ctx, query, tokenHash).Scan(
|
||||
&session.ID, &session.UserID, &session.TokenHash, &session.ExpiresAt,
|
||||
&session.CreatedAt, &session.UserAgent, &session.IPAddress,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &session, nil
|
||||
}
|
||||
|
||||
// DeleteSession deletes a session (logout)
|
||||
func (s *PostgresStore) DeleteSession(ctx context.Context, token string) error {
|
||||
hash := sha256.Sum256([]byte(token))
|
||||
tokenHash := hex.EncodeToString(hash[:])
|
||||
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE token_hash = $1", tokenHash)
|
||||
return err
|
||||
}
|
||||
|
||||
// CleanupExpiredSessions removes expired sessions
|
||||
func (s *PostgresStore) CleanupExpiredSessions(ctx context.Context) error {
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM sessions WHERE expires_at < NOW()")
|
||||
return err
|
||||
}
|
||||
193
backend/internal/store/share.go
Normal file
193
backend/internal/store/share.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// CreateDocumentShare creates a new share
|
||||
func (s *PostgresStore) CreateDocumentShare(ctx context.Context, documentID, userID uuid.UUID, permission string, createdBy *uuid.UUID) (*models.DocumentShare, error) {
|
||||
query := `
|
||||
INSERT INTO document_shares (document_id, user_id, permission, created_by)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (document_id, user_id) DO UPDATE SET permission = EXCLUDED.permission
|
||||
RETURNING id, document_id, user_id, permission, created_at, created_by
|
||||
`
|
||||
|
||||
var share models.DocumentShare
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, userID, permission, createdBy).Scan(
|
||||
&share.ID, &share.DocumentID, &share.UserID, &share.Permission,
|
||||
&share.CreatedAt, &share.CreatedBy,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &share, nil
|
||||
}
|
||||
|
||||
// ListDocumentShares lists all shares for a document
|
||||
func (s *PostgresStore) ListDocumentShares(ctx context.Context, documentID uuid.UUID) ([]models.DocumentShareWithUser, error) {
|
||||
query := `
|
||||
SELECT
|
||||
ds.id, ds.document_id, ds.user_id, ds.permission, ds.created_at, ds.created_by,
|
||||
u.id, u.email, u.name, u.avatar_url, u.provider, u.provider_user_id, u.created_at, u.updated_at, u.last_login_at
|
||||
FROM document_shares ds
|
||||
JOIN users u ON ds.user_id = u.id
|
||||
WHERE ds.document_id = $1
|
||||
ORDER BY ds.created_at DESC
|
||||
`
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, query, documentID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var shares []models.DocumentShareWithUser
|
||||
for rows.Next() {
|
||||
var share models.DocumentShareWithUser
|
||||
err := rows.Scan(
|
||||
&share.ID, &share.DocumentID, &share.UserID, &share.Permission, &share.CreatedAt, &share.CreatedBy,
|
||||
&share.User.ID, &share.User.Email, &share.User.Name, &share.User.AvatarURL, &share.User.Provider,
|
||||
&share.User.ProviderUserID, &share.User.CreatedAt, &share.User.UpdatedAt, &share.User.LastLoginAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shares = append(shares, share)
|
||||
}
|
||||
|
||||
return shares, nil
|
||||
}
|
||||
|
||||
// DeleteDocumentShare deletes a share
|
||||
func (s *PostgresStore) DeleteDocumentShare(ctx context.Context, documentID, userID uuid.UUID) error {
|
||||
_, err := s.db.ExecContext(ctx, "DELETE FROM document_shares WHERE document_id = $1 AND user_id = $2", documentID, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
// CanViewDocument checks if user can view document (owner OR has any share)
|
||||
func (s *PostgresStore) CanViewDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
|
||||
UNION
|
||||
SELECT 1 FROM document_shares WHERE document_id = $1 AND user_id = $2
|
||||
)
|
||||
`
|
||||
|
||||
var canView bool
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canView)
|
||||
return canView, err
|
||||
}
|
||||
|
||||
// CanEditDocument checks if user can edit document (owner OR has edit share)
|
||||
func (s *PostgresStore) CanEditDocument(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM documents WHERE id = $1 AND owner_id = $2
|
||||
UNION
|
||||
SELECT 1 FROM document_shares WHERE document_id = $1 AND user_id = $2 AND permission = 'edit'
|
||||
)
|
||||
`
|
||||
|
||||
var canEdit bool
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&canEdit)
|
||||
return canEdit, err
|
||||
}
|
||||
|
||||
// IsDocumentOwner checks if user is the owner
|
||||
func (s *PostgresStore) IsDocumentOwner(ctx context.Context, documentID, userID uuid.UUID) (bool, error) {
|
||||
query := `SELECT owner_id = $2 FROM documents WHERE id = $1`
|
||||
|
||||
var isOwner bool
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, userID).Scan(&isOwner)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return isOwner, err
|
||||
}
|
||||
func (s *PostgresStore) GenerateShareToken(ctx context.Context, documentID uuid.UUID, permission string) (string, error) {
|
||||
// Generate random 32-byte token
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(tokenBytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %w", err)
|
||||
}
|
||||
token := base64.URLEncoding.EncodeToString(tokenBytes)
|
||||
|
||||
// Update document with share token
|
||||
query := `
|
||||
UPDATE documents
|
||||
SET share_token = $1, is_public = true, updated_at = NOW()
|
||||
WHERE id = $2
|
||||
RETURNING share_token
|
||||
`
|
||||
|
||||
var shareToken string
|
||||
err := s.db.QueryRowContext(ctx, query, token, documentID).Scan(&shareToken)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to set share token: %w", err)
|
||||
}
|
||||
|
||||
return shareToken, nil
|
||||
}
|
||||
|
||||
// ValidateShareToken checks if a share token is valid for a document
|
||||
func (s *PostgresStore) ValidateShareToken(ctx context.Context, documentID uuid.UUID, token string) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM documents
|
||||
WHERE id = $1 AND share_token = $2 AND is_public = true
|
||||
)
|
||||
`
|
||||
|
||||
var exists bool
|
||||
err := s.db.QueryRowContext(ctx, query, documentID, token).Scan(&exists)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to validate share token: %w", err)
|
||||
}
|
||||
|
||||
return exists, nil
|
||||
}
|
||||
|
||||
// RevokeShareToken removes the public share link from a document
|
||||
func (s *PostgresStore) RevokeShareToken(ctx context.Context, documentID uuid.UUID) error {
|
||||
query := `
|
||||
UPDATE documents
|
||||
SET share_token = NULL, is_public = false, updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
|
||||
_, err := s.db.ExecContext(ctx, query, documentID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke share token: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetShareToken retrieves the current share token for a document (if exists)
|
||||
func (s *PostgresStore) GetShareToken(ctx context.Context, documentID uuid.UUID) (string, bool, error) {
|
||||
query := `
|
||||
SELECT share_token FROM documents
|
||||
WHERE id = $1 AND is_public = true AND share_token IS NOT NULL
|
||||
`
|
||||
|
||||
var token string
|
||||
err := s.db.QueryRowContext(ctx, query, documentID).Scan(&token)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", false, fmt.Errorf("failed to get share token: %w", err)
|
||||
}
|
||||
|
||||
return token, true, nil
|
||||
}
|
||||
82
backend/internal/store/user.go
Normal file
82
backend/internal/store/user.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// UpsertUser creates or updates user from OAuth profile
|
||||
func (s *PostgresStore) UpsertUser(ctx context.Context, provider, providerUserID, email, name string, avatarURL *string) (*models.User, error) {
|
||||
query := `
|
||||
INSERT INTO users (provider, provider_user_id, email, name, avatar_url, last_login_at)
|
||||
VALUES ($1, $2, $3, $4, $5, NOW())
|
||||
ON CONFLICT (provider, provider_user_id)
|
||||
DO UPDATE SET
|
||||
email = EXCLUDED.email,
|
||||
name = EXCLUDED.name,
|
||||
avatar_url = EXCLUDED.avatar_url,
|
||||
last_login_at = NOW(),
|
||||
updated_at = NOW()
|
||||
RETURNING id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
||||
`
|
||||
|
||||
var user models.User
|
||||
err := s.db.QueryRowContext(ctx, query, provider, providerUserID, email, name, avatarURL).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fmt.Printf("✅ User Upserted: ID=%s, Email=%s\n", user.ID.String(), user.Email)
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByID retrieves user by ID
|
||||
func (s *PostgresStore) GetUserByID(ctx context.Context, userID uuid.UUID) (*models.User, error) {
|
||||
query := `
|
||||
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
||||
FROM users WHERE id = $1
|
||||
`
|
||||
|
||||
var user models.User
|
||||
err := s.db.QueryRowContext(ctx, query, userID).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserByEmail retrieves user by email
|
||||
func (s *PostgresStore) GetUserByEmail(ctx context.Context, email string) (*models.User, error) {
|
||||
query := `
|
||||
SELECT id, email, name, avatar_url, provider, provider_user_id, created_at, updated_at, last_login_at
|
||||
FROM users WHERE email = $1
|
||||
`
|
||||
|
||||
var user models.User
|
||||
err := s.db.QueryRowContext(ctx, query, email).Scan(
|
||||
&user.ID, &user.Email, &user.Name, &user.AvatarURL, &user.Provider,
|
||||
&user.ProviderUserID, &user.CreatedAt, &user.UpdatedAt, &user.LastLoginAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
52
backend/scripts/001_add_users_and_sessions.sql
Normal file
52
backend/scripts/001_add_users_and_sessions.sql
Normal file
@@ -0,0 +1,52 @@
|
||||
-- Migration: Add users and sessions tables for authentication
|
||||
-- Run this before 002_add_document_shares.sql
|
||||
|
||||
-- Enable UUID extension
|
||||
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
|
||||
|
||||
-- Users table
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
email VARCHAR(255) NOT NULL,
|
||||
name VARCHAR(255) NOT NULL,
|
||||
avatar_url TEXT,
|
||||
provider VARCHAR(50) NOT NULL CHECK (provider IN ('google', 'github')),
|
||||
provider_user_id VARCHAR(255) NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
last_login_at TIMESTAMPTZ,
|
||||
UNIQUE(provider, provider_user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_users_email ON users(email);
|
||||
CREATE INDEX idx_users_provider ON users(provider, provider_user_id);
|
||||
|
||||
COMMENT ON TABLE users IS 'Stores user accounts from OAuth providers';
|
||||
COMMENT ON COLUMN users.provider IS 'OAuth provider: google or github';
|
||||
COMMENT ON COLUMN users.provider_user_id IS 'User ID from OAuth provider';
|
||||
|
||||
-- Sessions table
|
||||
CREATE TABLE IF NOT EXISTS sessions (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
token_hash VARCHAR(64) NOT NULL,
|
||||
expires_at TIMESTAMPTZ NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
user_agent TEXT,
|
||||
ip_address VARCHAR(45),
|
||||
UNIQUE(token_hash)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_sessions_user_id ON sessions(user_id);
|
||||
CREATE INDEX idx_sessions_token_hash ON sessions(token_hash);
|
||||
CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);
|
||||
|
||||
COMMENT ON TABLE sessions IS 'Stores active JWT sessions for revocation support';
|
||||
COMMENT ON COLUMN sessions.token_hash IS 'SHA-256 hash of JWT token';
|
||||
COMMENT ON COLUMN sessions.user_agent IS 'User agent string for device tracking';
|
||||
|
||||
-- Add owner_id to documents table if it doesn't exist
|
||||
ALTER TABLE documents ADD COLUMN IF NOT EXISTS owner_id UUID REFERENCES users(id) ON DELETE SET NULL;
|
||||
CREATE INDEX IF NOT EXISTS idx_documents_owner_id ON documents(owner_id);
|
||||
|
||||
COMMENT ON COLUMN documents.owner_id IS 'User who created the document';
|
||||
19
backend/scripts/002_add_document_shares.sql
Normal file
19
backend/scripts/002_add_document_shares.sql
Normal file
@@ -0,0 +1,19 @@
|
||||
-- Migration: Add document sharing with permissions
|
||||
-- Run against existing database
|
||||
|
||||
CREATE TABLE IF NOT EXISTS document_shares (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
permission VARCHAR(20) NOT NULL CHECK (permission IN ('view', 'edit')),
|
||||
created_at TIMESTAMPTZ DEFAULT NOW(),
|
||||
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
UNIQUE(document_id, user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_shares_document_id ON document_shares(document_id);
|
||||
CREATE INDEX idx_shares_user_id ON document_shares(user_id);
|
||||
CREATE INDEX idx_shares_permission ON document_shares(document_id, permission);
|
||||
|
||||
COMMENT ON TABLE document_shares IS 'Stores per-user document access permissions';
|
||||
COMMENT ON COLUMN document_shares.permission IS 'Access level: view (read-only) or edit (read-write)';
|
||||
@@ -1,10 +1,10 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import {
|
||||
createYjsDocument,
|
||||
destroyYjsDocument,
|
||||
getRandomColor,
|
||||
getRandomName,
|
||||
type YjsProviders,
|
||||
createYjsDocument,
|
||||
destroyYjsDocument,
|
||||
getRandomColor,
|
||||
getRandomName,
|
||||
type YjsProviders,
|
||||
} from "../lib/yjs";
|
||||
import { useAutoSave } from "./useAutoSave";
|
||||
|
||||
@@ -19,7 +19,6 @@ export const useYjsDocument = (documentId: string) => {
|
||||
let mounted = true;
|
||||
let currentProviders: YjsProviders | null = null;
|
||||
|
||||
// Create Yjs document and providers
|
||||
const initializeDocument = async () => {
|
||||
const yjsProviders = await createYjsDocument(documentId);
|
||||
currentProviders = yjsProviders;
|
||||
@@ -30,19 +29,75 @@ export const useYjsDocument = (documentId: string) => {
|
||||
}
|
||||
|
||||
// Set user info for awareness
|
||||
const userName = getRandomName();
|
||||
const userColor = getRandomColor();
|
||||
yjsProviders.awareness.setLocalStateField("user", {
|
||||
name: getRandomName(),
|
||||
color: getRandomColor(),
|
||||
name: userName,
|
||||
color: userColor,
|
||||
});
|
||||
|
||||
// NEW: Add awareness event logging
|
||||
const handleAwarenessChange = ({
|
||||
added,
|
||||
updated,
|
||||
removed,
|
||||
}: {
|
||||
added: number[];
|
||||
updated: number[];
|
||||
removed: number[];
|
||||
}) => {
|
||||
const states = yjsProviders.awareness.getStates();
|
||||
|
||||
added.forEach((clientId) => {
|
||||
const state = states.get(clientId);
|
||||
const user = state?.user;
|
||||
console.log(
|
||||
`[Awareness] User connected: ${
|
||||
user?.name || "Unknown"
|
||||
} (ID: ${clientId})`,
|
||||
{
|
||||
color: user?.color,
|
||||
clientId,
|
||||
}
|
||||
);
|
||||
});
|
||||
|
||||
updated.forEach((clientId) => {
|
||||
const state = states.get(clientId);
|
||||
const user = state?.user;
|
||||
console.log(
|
||||
`[Awareness] User updated: ${
|
||||
user?.name || "Unknown"
|
||||
} (ID: ${clientId})`
|
||||
);
|
||||
});
|
||||
|
||||
removed.forEach((clientId) => {
|
||||
console.log(`[Awareness] User disconnected (ID: ${clientId})`);
|
||||
});
|
||||
|
||||
console.log(`[Awareness] Total connected users: ${states.size}`);
|
||||
};
|
||||
|
||||
yjsProviders.awareness.on("change", handleAwarenessChange);
|
||||
|
||||
// Listen for sync status
|
||||
yjsProviders.indexeddbProvider.on("synced", () => {
|
||||
console.log("IndexedDB synced");
|
||||
setSynced(true);
|
||||
});
|
||||
|
||||
yjsProviders.websocketProvider.on("status", (event: { status: string }) => {
|
||||
console.log("WebSocket status:", event.status);
|
||||
yjsProviders.websocketProvider.on(
|
||||
"status",
|
||||
(event: { status: string }) => {
|
||||
console.log("WebSocket status:", event.status);
|
||||
}
|
||||
);
|
||||
|
||||
// Log local user info
|
||||
console.log(`[Awareness] Local user initialized: ${userName}`, {
|
||||
color: userColor,
|
||||
clientId: yjsProviders.awareness.clientID,
|
||||
});
|
||||
|
||||
setProviders(yjsProviders);
|
||||
@@ -54,10 +109,13 @@ export const useYjsDocument = (documentId: string) => {
|
||||
return () => {
|
||||
mounted = false;
|
||||
if (currentProviders) {
|
||||
console.log("[Awareness] Cleaning up local user");
|
||||
currentProviders.awareness.setLocalState(null);
|
||||
destroyYjsDocument(currentProviders);
|
||||
}
|
||||
};
|
||||
}, [documentId]);
|
||||
|
||||
|
||||
return { providers, synced };
|
||||
};
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { IndexeddbPersistence } from "y-indexeddb";
|
||||
import { Awareness } from "y-protocols/awareness";
|
||||
import { WebsocketProvider } from "y-websocket";
|
||||
import * as Y from "yjs";
|
||||
import { documentsApi } from "../api/document";
|
||||
@@ -9,7 +10,7 @@ export interface YjsProviders {
|
||||
ydoc: Y.Doc;
|
||||
websocketProvider: WebsocketProvider;
|
||||
indexeddbProvider: IndexeddbPersistence;
|
||||
awareness: any;
|
||||
awareness: Awareness;
|
||||
}
|
||||
|
||||
export const createYjsDocument = async (documentId: string): Promise<YjsProviders> => {
|
||||
|
||||
Reference in New Issue
Block a user