Compare commits

..

9 Commits

Author SHA1 Message Date
M1ngdaXie
afb04e5cd3 feat: migrate realtime-collab from Docker Compose to k3s
Add k3s manifests for postgres, redis, and backend

Fix users table constraint and init.sql
2026-03-25 01:19:00 +00:00
M1ngdaXie
9c19769eb0 feat: add guest mode, bug fixes, and self-hosted config
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-15 09:45:17 +00:00
M1ngdaXie
763575f284 fix: restore original URL after OAuth login redirect
Save the intended destination to sessionStorage before navigating to
the OAuth provider, and read it back in AuthCallback after login.
Also handles 401-triggered redirects so session-expired users are
returned to the page they were on.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-09 04:42:19 +00:00
M1ngdaXie
731bd67334 Add self-hosted deployment configuration
- Add backend entry point (cmd/server/main.go)
- Add prompt=select_account to Google OAuth flow
- Add combined init.sql for self-hosted PostgreSQL
- Update docker-compose to include backend service with memory limits

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-09 01:38:15 +00:00
M1ngdaXie
50822600ad feat: implement Redis Streams support with stream checkpoints and update history
- Added Redis Streams operations to the message bus interface and implementation.
- Introduced StreamCheckpoint model to track last processed stream entry per document.
- Implemented UpsertStreamCheckpoint and GetStreamCheckpoint methods in the Postgres store.
- Created document_update_history table for storing update payloads for recovery and replay.
- Developed update persist worker to handle Redis Stream updates and persist them to Postgres.
- Enhanced Docker Compose configuration for Redis with persistence.
- Updated frontend API to support fetching document state with optional share token.
- Added connection stability monitoring in the Yjs document hook.
2026-03-08 17:13:42 -07:00
M1ngdaXie
f319e8ec75 feat(kanban): implement task reordering and improve task movement logic
feat(share): add documentType prop to ShareModal for dynamic URL generation
2026-02-08 16:38:02 -08:00
M1ngdaXie
3179ead0a5 feat(assets): add docnest icons in various resolutions 2026-02-08 16:26:25 -08:00
M1ngdaXie
10fd9cdecb dark mode 2026-02-08 16:23:06 -08:00
M1ngdaXie
10110e26b3 remove unused variables in EditorPage and KanbanPage 2026-02-08 12:36:00 -08:00
61 changed files with 2830 additions and 154 deletions

8
.gitignore vendored
View File

@@ -35,3 +35,11 @@ build/
postgres_data/ postgres_data/
.claude/ .claude/
#test folder profiles
loadtest/pprof
/docs
# K3s secrets
k3s/secret.yaml

2
backend/.gitignore vendored
View File

@@ -4,7 +4,7 @@
.env.*.local .env.*.local
# Compiled binaries # Compiled binaries
server /server
*.exe *.exe
*.exe~ *.exe~
*.dll *.dll

282
backend/cmd/server/main.go Normal file
View File

@@ -0,0 +1,282 @@
package main
import (
"context"
"flag"
"fmt"
"log"
"net"
"net/http/pprof"
"os"
"strconv"
"strings"
"time"
"runtime"
"github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/config"
"github.com/M1ngdaXie/realtime-collab/internal/handlers"
"github.com/M1ngdaXie/realtime-collab/internal/hub"
"github.com/M1ngdaXie/realtime-collab/internal/logger"
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
"github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/M1ngdaXie/realtime-collab/internal/workers"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"go.uber.org/zap"
)
// main is the entry point for the realtime-collab backend. It wires together,
// in order: configuration, structured logging, the message bus (Redis with an
// in-process local fallback), the Postgres store, the WebSocket hub,
// background workers, the HTTP handlers, and finally the Gin router, then
// blocks serving HTTP on the configured port.
func main() {
	// CLI flags - override env vars
	portFlag := flag.String("port", "", "Server port (overrides PORT env var)")
	flag.Parse()

	// Load configuration; the -port flag, when non-empty, is passed through
	// so it can take precedence over the PORT environment variable.
	cfg, err := config.Load(*portFlag)
	if err != nil {
		log.Fatalf("Configuration error: %v", err)
	}
	log.Printf("Configuration loaded (environment: %s, port: %s)", cfg.Environment, cfg.Port)

	// Initialize structured logger (zap), configured from the environment.
	zapLogger, err := logger.NewLoggerFromEnv()
	if err != nil {
		log.Fatalf("Failed to initialize logger: %v", err)
	}
	defer zapLogger.Sync()

	// Generate unique server ID for this instance: hostname plus a short
	// random UUID prefix, so multiple replicas on one host stay distinct.
	hostname, _ := os.Hostname()
	serverID := fmt.Sprintf("%s-%s", hostname, uuid.New().String()[:8])
	zapLogger.Info("Server identity", zap.String("server_id", serverID))

	// Initialize MessageBus (Redis or Local fallback). A failed Redis
	// connection degrades to single-instance local mode instead of aborting.
	var msgBus messagebus.MessageBus
	if cfg.RedisURL != "" {
		redisBus, err := messagebus.NewRedisMessageBus(cfg.RedisURL, serverID, zapLogger)
		if err != nil {
			zapLogger.Warn("Redis unavailable, falling back to local mode", zap.Error(err))
			msgBus = messagebus.NewLocalMessageBus()
		} else {
			msgBus = redisBus
		}
	} else {
		zapLogger.Info("No REDIS_URL configured, using local mode")
		msgBus = messagebus.NewLocalMessageBus()
	}
	defer msgBus.Close()

	// Initialize database.
	// NOTE(review): log.Fatalf exits the process, so the deferred
	// Sync/Close calls registered above do not run on this path — confirm
	// that is acceptable for the message bus.
	dbStore, err := store.NewPostgresStore(cfg.DatabaseURL)
	if err != nil {
		log.Fatalf("Failed to initialize database: %v", err)
	}
	defer dbStore.Close()
	log.Println("Database connection established")

	// Initialize WebSocket hub and start its event loop in the background.
	wsHub := hub.NewHub(msgBus, serverID, zapLogger)
	go wsHub.Run()
	zapLogger.Info("WebSocket hub started")

	// Start Redis health monitoring (if using Redis). Health transitions
	// toggle the hub's fallback mode: unhealthy -> fallback on.
	if redisBus, ok := msgBus.(*messagebus.RedisMessageBus); ok {
		go redisBus.StartHealthMonitoring(context.Background(), 30*time.Second, func(healthy bool) {
			wsHub.SetFallbackMode(!healthy)
		})
		zapLogger.Info("Redis health monitoring started")
	}

	// Start update persist worker (stream WAL persistence), cancelable via
	// workerCtx when main returns.
	workerCtx, workerCancel := context.WithCancel(context.Background())
	defer workerCancel()
	go workers.StartUpdatePersistWorker(workerCtx, msgBus, dbStore, zapLogger, serverID)
	zapLogger.Info("Update persist worker started")

	// Start periodic session cleanup (every hour).
	// NOTE(review): this goroutine has no stop signal; it runs for the
	// life of the process.
	go func() {
		ticker := time.NewTicker(1 * time.Hour)
		defer ticker.Stop()
		for range ticker.C {
			if err := dbStore.CleanupExpiredSessions(context.Background()); err != nil {
				log.Printf("Error cleaning up expired sessions: %v", err)
			} else {
				log.Println("Cleaned up expired sessions")
			}
		}
	}()
	log.Println("Session cleanup task started")

	// Initialize handlers and the JWT auth middleware.
	docHandler := handlers.NewDocumentHandler(dbStore, msgBus, serverID, zapLogger)
	wsHandler := handlers.NewWebSocketHandler(wsHub, dbStore, cfg, msgBus)
	authHandler := handlers.NewAuthHandler(dbStore, cfg)
	authMiddleware := auth.NewAuthMiddleware(dbStore, cfg.JWTSecret, zapLogger)
	shareHandler := handlers.NewShareHandler(dbStore, cfg)
	versionHandler := handlers.NewVersionHandler(dbStore)

	// Setup Gin router
	router := gin.Default()

	// Optional pprof endpoints for profiling under load (guarded by env).
	// Enable with: ENABLE_PPROF=1
	// Optional: PPROF_BLOCK_RATE=1 PPROF_MUTEX_FRACTION=1 (adds overhead; use for short profiling windows).
	if shouldEnablePprof(cfg) {
		blockRate := getEnvInt("PPROF_BLOCK_RATE", 0)
		mutexFraction := getEnvInt("PPROF_MUTEX_FRACTION", 0)
		localOnly := getEnvBool("PPROF_LOCAL_ONLY", true)
		if blockRate > 0 {
			runtime.SetBlockProfileRate(blockRate)
		}
		if mutexFraction > 0 {
			runtime.SetMutexProfileFraction(mutexFraction)
		}
		pprofGroup := router.Group("/debug/pprof")
		// Restrict pprof to loopback clients unless explicitly disabled
		// via PPROF_LOCAL_ONLY=0.
		if localOnly {
			pprofGroup.Use(func(c *gin.Context) {
				ip := net.ParseIP(c.ClientIP())
				if ip == nil || !ip.IsLoopback() {
					c.AbortWithStatus(403)
					return
				}
				c.Next()
			})
		}
		// Optional HTTP basic auth on the pprof group; only applied when
		// both PPROF_USER and PPROF_PASS are set.
		user, pass := os.Getenv("PPROF_USER"), os.Getenv("PPROF_PASS")
		if user != "" || pass != "" {
			if user == "" || pass == "" {
				zapLogger.Warn("PPROF_USER/PPROF_PASS must both be set; skipping basic auth")
			} else {
				pprofGroup.Use(gin.BasicAuth(gin.Accounts{user: pass}))
			}
		}
		pprofGroup.GET("/", gin.WrapF(pprof.Index))
		pprofGroup.GET("/cmdline", gin.WrapF(pprof.Cmdline))
		pprofGroup.GET("/profile", gin.WrapF(pprof.Profile))
		pprofGroup.GET("/symbol", gin.WrapF(pprof.Symbol))
		pprofGroup.GET("/trace", gin.WrapF(pprof.Trace))
		pprofGroup.GET("/allocs", gin.WrapH(pprof.Handler("allocs")))
		pprofGroup.GET("/block", gin.WrapH(pprof.Handler("block")))
		pprofGroup.GET("/goroutine", gin.WrapH(pprof.Handler("goroutine")))
		pprofGroup.GET("/heap", gin.WrapH(pprof.Handler("heap")))
		pprofGroup.GET("/mutex", gin.WrapH(pprof.Handler("mutex")))
		pprofGroup.GET("/threadcreate", gin.WrapH(pprof.Handler("threadcreate")))
		zapLogger.Info("pprof enabled",
			zap.Bool("local_only", localOnly),
			zap.Int("block_rate", blockRate),
			zap.Int("mutex_fraction", mutexFraction),
		)
	}

	// CORS configuration: origins come from config; credentials allowed.
	corsConfig := cors.DefaultConfig()
	corsConfig.AllowOrigins = cfg.AllowedOrigins
	corsConfig.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
	corsConfig.AllowHeaders = []string{"Origin", "Content-Type", "Accept", "Authorization"}
	corsConfig.AllowCredentials = true
	router.Use(cors.New(corsConfig))

	// Health check
	router.GET("/health", func(c *gin.Context) {
		c.JSON(200, gin.H{"status": "ok"})
	})

	// WebSocket endpoint (no auth required, validated in handler)
	router.GET("/ws/:roomId", wsHandler.HandleWebSocket)
	// Load test endpoint - NO AUTH (only for local testing!)
	router.GET("/ws/loadtest/:roomId", wsHandler.HandleWebSocketLoadTest)

	// REST API
	api := router.Group("/api")
	authGroup := api.Group("/auth")
	{
		authGroup.GET("/google", authHandler.GoogleLogin)
		authGroup.GET("/google/callback", authHandler.GoogleCallback)
		authGroup.GET("/github", authHandler.GithubLogin)
		authGroup.GET("/github/callback", authHandler.GithubCallback)
		authGroup.POST("/guest", authHandler.GuestLogin)
		authGroup.GET("/me", authMiddleware.RequireAuth(), authHandler.Me)
		authGroup.POST("/logout", authMiddleware.RequireAuth(), authHandler.Logout)
	}

	// Document routes with optional auth
	docs := api.Group("/documents")
	{
		docs.GET("", authMiddleware.RequireAuth(), docHandler.ListDocuments)
		docs.GET("/:id", authMiddleware.RequireAuth(), docHandler.GetDocument)
		docs.GET("/:id/state", authMiddleware.OptionalAuth(), docHandler.GetDocumentState)
		// Permission route (supports both auth and share token)
		docs.GET("/:id/permission", authMiddleware.OptionalAuth(), docHandler.GetDocumentPermission)
		docs.POST("", authMiddleware.RequireAuth(), docHandler.CreateDocument)
		docs.PUT("/:id/state", authMiddleware.RequireAuth(), docHandler.UpdateDocumentState)
		docs.DELETE("/:id", authMiddleware.RequireAuth(), docHandler.DeleteDocument)
		// Share routes
		docs.POST("/:id/shares", authMiddleware.RequireAuth(), shareHandler.CreateShare)
		docs.GET("/:id/shares", authMiddleware.RequireAuth(), shareHandler.ListShares)
		docs.DELETE("/:id/shares/:userId", authMiddleware.RequireAuth(), shareHandler.DeleteShare)
		docs.POST("/:id/share-link", authMiddleware.RequireAuth(), shareHandler.CreateShareLink)
		docs.GET("/:id/share-link", authMiddleware.RequireAuth(), shareHandler.GetShareLink)
		docs.DELETE("/:id/share-link", authMiddleware.RequireAuth(), shareHandler.RevokeShareLink)
		// Version history routes
		docs.POST("/:id/versions", authMiddleware.RequireAuth(), versionHandler.CreateVersion)
		docs.GET("/:id/versions", authMiddleware.RequireAuth(), versionHandler.ListVersions)
		docs.GET("/:id/versions/:versionId/snapshot", authMiddleware.RequireAuth(), versionHandler.GetVersionSnapshot)
		docs.POST("/:id/restore", authMiddleware.RequireAuth(), versionHandler.RestoreVersion)
	}

	// Start server (blocks until the listener fails or the process exits).
	log.Printf("Server starting on port %s", cfg.Port)
	if err := router.Run(":" + cfg.Port); err != nil {
		log.Fatalf("Failed to start server: %v", err)
	}
}
// shouldEnablePprof reports whether the pprof debug endpoints should be
// registered. It is always false in production (or with a nil config);
// otherwise it follows the ENABLE_PPROF environment variable (default off).
func shouldEnablePprof(cfg *config.Config) bool {
	switch {
	case cfg == nil:
		return false
	case cfg.IsProduction():
		return false
	default:
		return getEnvBool("ENABLE_PPROF", false)
	}
}
// getEnvBool reads the environment variable key and interprets it as a
// boolean. Truthy values: 1, true, t, yes, y, on; falsy values: 0, false,
// f, no, n, off (case-insensitive, surrounding whitespace ignored). An
// unset variable or an unrecognized value yields defaultValue.
func getEnvBool(key string, defaultValue bool) bool {
	raw, ok := os.LookupEnv(key)
	if !ok {
		return defaultValue
	}
	known := map[string]bool{
		"1": true, "true": true, "t": true, "yes": true, "y": true, "on": true,
		"0": false, "false": false, "f": false, "no": false, "n": false, "off": false,
	}
	if v, found := known[strings.ToLower(strings.TrimSpace(raw))]; found {
		return v
	}
	return defaultValue
}
// getEnvInt reads the environment variable key and parses it as a base-10
// integer (surrounding whitespace ignored). An unset variable or a value
// that fails to parse yields defaultValue.
func getEnvInt(key string, defaultValue int) int {
	raw, ok := os.LookupEnv(key)
	if !ok {
		return defaultValue
	}
	if n, err := strconv.Atoi(strings.TrimSpace(raw)); err == nil {
		return n
	}
	return defaultValue
}

View File

@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap"
) )
type contextKey string type contextKey string
@@ -20,41 +21,40 @@ const ContextUserIDKey = "user_id"
type AuthMiddleware struct { type AuthMiddleware struct {
store store.Store store store.Store
jwtSecret string jwtSecret string
logger *zap.Logger
} }
// NewAuthMiddleware creates a new auth middleware // NewAuthMiddleware creates a new auth middleware
func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware { func NewAuthMiddleware(store store.Store, jwtSecret string, logger *zap.Logger) *AuthMiddleware {
if logger == nil {
logger = zap.NewNop()
}
return &AuthMiddleware{ return &AuthMiddleware{
store: store, store: store,
jwtSecret: jwtSecret, jwtSecret: jwtSecret,
logger: logger,
} }
} }
// RequireAuth middleware requires valid authentication // RequireAuth middleware requires valid authentication
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc { func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
fmt.Println("🔒 RequireAuth: Starting authentication check")
user, claims, err := m.getUserFromToken(c) user, claims, err := m.getUserFromToken(c)
fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
if claims != nil {
fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
}
if err != nil || user == nil { if err != nil || user == nil {
fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user) if err != nil {
m.logger.Warn("auth failed",
zap.Error(err),
zap.String("method", c.Request.Method),
zap.String("path", c.FullPath()),
)
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort() c.Abort()
return return
} }
// Note: Name and Email might be empty for old JWT tokens // Note: Name and Email might be empty for old JWT tokens
if claims.Name == "" || claims.Email == "" {
fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
}
fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
c.Set(ContextUserIDKey, user) c.Set(ContextUserIDKey, user)
c.Set("user_email", claims.Email) c.Set("user_email", claims.Email)
c.Set("user_name", claims.Name) c.Set("user_name", claims.Name)
@@ -88,21 +88,17 @@ func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error) // 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error)
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) { func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)
if authHeader == "" { if authHeader == "" {
fmt.Println("⚠️ getUserFromToken: No Authorization header")
return nil, nil, nil return nil, nil, nil
} }
parts := strings.Split(authHeader, " ") parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" { if len(parts) != 2 || parts[0] != "Bearer" {
fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
return nil, nil, nil return nil, nil, nil
} }
tokenString := parts[1] tokenString := parts[1]
fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
// 必须要验证签名算法是 HMAC (HS256) // 必须要验证签名算法是 HMAC (HS256)
@@ -113,7 +109,6 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
}) })
if err != nil { if err != nil {
fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
return nil, nil, err return nil, nil, err
} }
@@ -123,17 +118,14 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String() // 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
userID, err := uuid.Parse(claims.Subject) userID, err := uuid.Parse(claims.Subject)
if err != nil { if err != nil {
fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
return nil, nil, fmt.Errorf("invalid user ID in token") return nil, nil, fmt.Errorf("invalid user ID in token")
} }
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email) // 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
// 这一步完全没有查数据库,速度极快 // 这一步完全没有查数据库,速度极快
fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
return &userID, claims, nil return &userID, claims, nil
} }
fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
return nil, nil, fmt.Errorf("invalid token claims") return nil, nil, fmt.Errorf("invalid token claims")
} }
@@ -141,8 +133,6 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
func GetUserFromContext(c *gin.Context) *uuid.UUID { func GetUserFromContext(c *gin.Context) *uuid.UUID {
// 修正点:使用和存入时完全一样的 Key // 修正点:使用和存入时完全一样的 Key
val, exists := c.Get(ContextUserIDKey) val, exists := c.Get(ContextUserIDKey)
fmt.Println("within getFromContext the id is ... ")
fmt.Println(val)
if !exists { if !exists {
return nil return nil
} }

View File

@@ -6,7 +6,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log"
"net/http" "net/http"
"time" "time"
@@ -57,7 +57,7 @@ func NewAuthHandler(store store.Store, cfg *config.Config) *AuthHandler {
func (h *AuthHandler) GoogleLogin(c *gin.Context) { func (h *AuthHandler) GoogleLogin(c *gin.Context) {
// Generate random state and set cookie // Generate random state and set cookie
oauthState := h.generateStateOauthCookie(c.Writer) oauthState := h.generateStateOauthCookie(c.Writer)
url := h.googleConfig.AuthCodeURL(oauthState, oauth2.AccessTypeOffline) url := h.googleConfig.AuthCodeURL(oauthState, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "select_account"))
c.Redirect(http.StatusTemporaryRedirect, url) c.Redirect(http.StatusTemporaryRedirect, url)
} }
@@ -68,7 +68,7 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"}) c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return return
} }
log.Println("Google callback state:", c.Query("state"))
// Exchange code for token // Exchange code for token
token, err := h.googleConfig.Exchange(c.Request.Context(), c.Query("code")) token, err := h.googleConfig.Exchange(c.Request.Context(), c.Query("code"))
if err != nil { if err != nil {
@@ -83,8 +83,7 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to get user info"})
return return
} }
log.Println("Google user info response status:", resp.Status)
log.Println("Google user info response headers:", resp.Header)
defer resp.Body.Close() defer resp.Body.Close()
data, _ := io.ReadAll(resp.Body) data, _ := io.ReadAll(resp.Body)
@@ -96,11 +95,11 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
} }
if err := json.Unmarshal(data, &userInfo); err != nil { if err := json.Unmarshal(data, &userInfo); err != nil {
log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data)) // Failed to parse Google response
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"})
return return
} }
log.Println("Google user info:", userInfo)
// Upsert user in database // Upsert user in database
user, err := h.store.UpsertUser( user, err := h.store.UpsertUser(
c.Request.Context(), c.Request.Context(),
@@ -116,12 +115,9 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
} }
// Create session and JWT // Create session and JWT
jwt, err := h.createSessionAndJWT(c, user) jwt, err := h.createSessionAndJWT(c, user, 7*24*time.Hour)
if err != nil { if err != nil {
fmt.Printf("❌ DATABASE ERROR: %v\n", err) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"})
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("CreateSession Error: %v", err),
})
return return
} }
@@ -144,7 +140,7 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"}) c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return return
} }
log.Println("Github callback state:", c.Query("state"))
code := c.Query("code") code := c.Query("code")
if code == "" { if code == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "No code provided"}) c.JSON(http.StatusBadRequest, gin.H{"error": "No code provided"})
@@ -178,7 +174,7 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
AvatarURL string `json:"avatar_url"` AvatarURL string `json:"avatar_url"`
} }
if err := json.Unmarshal(data, &userInfo); err != nil { if err := json.Unmarshal(data, &userInfo); err != nil {
log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data)) // Failed to parse GitHub response
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"})
return return
} }
@@ -207,8 +203,7 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
if userInfo.Name == "" { if userInfo.Name == "" {
userInfo.Name = userInfo.Login userInfo.Name = userInfo.Login
} }
fmt.Println("Getting user info : ")
fmt.Println(userInfo)
// Upsert user in database // Upsert user in database
user, err := h.store.UpsertUser( user, err := h.store.UpsertUser(
c.Request.Context(), c.Request.Context(),
@@ -224,7 +219,7 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
} }
// Create session and JWT // Create session and JWT
jwt, err := h.createSessionAndJWT(c, user) jwt, err := h.createSessionAndJWT(c, user, 7*24*time.Hour)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"})
return return
@@ -273,12 +268,47 @@ func (h *AuthHandler) Logout(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"}) c.JSON(http.StatusOK, gin.H{"message": "Logged out successfully"})
} }
// GuestLogin creates a temporary guest user and returns a JWT
func (h *AuthHandler) GuestLogin(c *gin.Context) {
// Generate random 4-byte hex string for guest ID
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to generate guest ID"})
return
}
guestHex := fmt.Sprintf("%x", b)
guestName := fmt.Sprintf("Guest-%s", guestHex)
guestEmail := fmt.Sprintf("guest-%s@guest.local", guestHex)
providerUserID := uuid.New().String()
user, err := h.store.UpsertUser(
c.Request.Context(),
"guest",
providerUserID,
guestEmail,
guestName,
nil,
)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create guest user"})
return
}
jwt, err := h.createSessionAndJWT(c, user, 24*time.Hour)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to create session"})
return
}
c.JSON(http.StatusOK, gin.H{"token": jwt})
}
// Helper: create session and JWT // Helper: create session and JWT
func (h *AuthHandler) createSessionAndJWT(c *gin.Context, user *models.User) (string, error) { func (h *AuthHandler) createSessionAndJWT(c *gin.Context, user *models.User, expiry time.Duration) (string, error) {
expiresAt := time.Now().Add(7 * 24 * time.Hour) // 7 days expiresAt := time.Now().Add(expiry)
// Generate JWT first (we need it for session) - now includes avatar URL // Generate JWT first (we need it for session) - now includes avatar URL
jwt, err := auth.GenerateJWT(user.ID, user.Name, user.Email, user.AvatarURL, h.cfg.JWTSecret, 7*24*time.Hour) jwt, err := auth.GenerateJWT(user.ID, user.Name, user.Email, user.AvatarURL, h.cfg.JWTSecret, expiry)
if err != nil { if err != nil {
return "", err return "", err
} }
@@ -306,7 +336,7 @@ func (h *AuthHandler) generateStateOauthCookie(w http.ResponseWriter) string {
b := make([]byte, 16) b := make([]byte, 16)
n, err := rand.Read(b) n, err := rand.Read(b)
if err != nil || n != 16 { if err != nil || n != 16 {
fmt.Printf("Failed to generate random state: %v\n", err) // Failed to generate random state
return "" // Critical for CSRF security return "" // Critical for CSRF security
} }
state := base64.URLEncoding.EncodeToString(b) state := base64.URLEncoding.EncodeToString(b)

View File

@@ -1,22 +1,33 @@
package handlers package handlers
import ( import (
"fmt" "context"
"net/http" "net/http"
"time"
"github.com/M1ngdaXie/realtime-collab/internal/auth" "github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
"github.com/M1ngdaXie/realtime-collab/internal/models" "github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/M1ngdaXie/realtime-collab/internal/store" "github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap"
) )
type DocumentHandler struct { type DocumentHandler struct {
store *store.PostgresStore store *store.PostgresStore
messageBus messagebus.MessageBus
serverID string
logger *zap.Logger
} }
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler { func NewDocumentHandler(s *store.PostgresStore, msgBus messagebus.MessageBus, serverID string, logger *zap.Logger) *DocumentHandler {
return &DocumentHandler{store: s} return &DocumentHandler{
store: s,
messageBus: msgBus,
serverID: serverID,
logger: logger,
}
} }
// CreateDocument creates a new document (requires auth) // CreateDocument creates a new document (requires auth)
@@ -45,8 +56,6 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
func (h *DocumentHandler) ListDocuments(c *gin.Context) { func (h *DocumentHandler) ListDocuments(c *gin.Context) {
userID := auth.GetUserFromContext(c) userID := auth.GetUserFromContext(c)
fmt.Println("Getting userId, which is : ")
fmt.Println(userID)
if userID == nil { if userID == nil {
respondUnauthorized(c, "Authentication required to list documents") respondUnauthorized(c, "Authentication required to list documents")
return return
@@ -113,6 +122,13 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
} }
userID := auth.GetUserFromContext(c) userID := auth.GetUserFromContext(c)
shareToken := c.Query("share")
doc, err := h.store.GetDocument(id)
if err != nil {
respondNotFound(c, "document")
return
}
// Check permission if authenticated // Check permission if authenticated
if userID != nil { if userID != nil {
@@ -125,12 +141,22 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
respondForbidden(c, "Access denied") respondForbidden(c, "Access denied")
return return
} }
} } else {
// Unauthenticated: require valid share token or public doc
doc, err := h.store.GetDocument(id) if shareToken != "" {
if err != nil { valid, err := h.store.ValidateShareToken(c.Request.Context(), id, shareToken)
respondNotFound(c, "document") if err != nil {
return respondInternalError(c, "Failed to validate share token", err)
return
}
if !valid {
respondForbidden(c, "Invalid or expired share token")
return
}
} else if !doc.Is_Public {
respondForbidden(c, "This document is not public. Please sign in to access.")
return
}
} }
// Return empty byte slice if state is nil (new document) // Return empty byte slice if state is nil (new document)
@@ -191,6 +217,16 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
return return
} }
if streamID, seq, ok := h.addSnapshotMarker(c.Request.Context(), id); ok {
if err := h.store.UpsertStreamCheckpoint(c.Request.Context(), id, streamID, seq); err != nil {
if h.logger != nil {
h.logger.Warn("Failed to upsert stream checkpoint after snapshot",
zap.String("document_id", id.String()),
zap.Error(err))
}
}
}
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"}) c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
} }
@@ -234,6 +270,43 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"}) c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"})
} }
func (h *DocumentHandler) addSnapshotMarker(ctx context.Context, documentID uuid.UUID) (string, int64, bool) {
if h.messageBus == nil {
return "", 0, false
}
streamKey := "stream:" + documentID.String()
seqKey := "seq:" + documentID.String()
seq, err := h.messageBus.Incr(ctx, seqKey)
if err != nil {
if h.logger != nil {
h.logger.Warn("Failed to increment snapshot sequence",
zap.String("document_id", documentID.String()),
zap.Error(err))
}
return "", 0, false
}
values := map[string]interface{}{
"type": "snapshot",
"server_id": h.serverID,
"seq": seq,
"timestamp": time.Now().Format(time.RFC3339),
}
streamID, err := h.messageBus.XAdd(ctx, streamKey, 10000, true, values)
if err != nil {
if h.logger != nil {
h.logger.Warn("Failed to add snapshot marker to stream",
zap.String("stream_key", streamKey),
zap.Error(err))
}
return "", 0, false
}
return streamID, seq, true
}
// GetDocumentPermission returns the user's permission level for a document // GetDocumentPermission returns the user's permission level for a document
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) { func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id")) documentID, err := uuid.Parse(c.Param("id"))

View File

@@ -7,10 +7,12 @@ import (
"testing" "testing"
"github.com/M1ngdaXie/realtime-collab/internal/auth" "github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
"github.com/M1ngdaXie/realtime-collab/internal/models" "github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"go.uber.org/zap"
) )
// DocumentHandlerSuite tests document CRUD operations // DocumentHandlerSuite tests document CRUD operations
@@ -23,7 +25,7 @@ type DocumentHandlerSuite struct {
// SetupTest runs before each test // SetupTest runs before each test
func (s *DocumentHandlerSuite) SetupTest() { func (s *DocumentHandlerSuite) SetupTest() {
s.BaseHandlerSuite.SetupTest() s.BaseHandlerSuite.SetupTest()
s.handler = NewDocumentHandler(s.store) s.handler = NewDocumentHandler(s.store, messagebus.NewLocalMessageBus(), "test-server", zap.NewNop())
s.setupRouter() s.setupRouter()
} }

View File

@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"go.uber.org/zap"
) )
// ShareHandlerSuite tests for share handler endpoints // ShareHandlerSuite tests for share handler endpoints
@@ -24,7 +25,7 @@ func (s *ShareHandlerSuite) SetupTest() {
s.BaseHandlerSuite.SetupTest() s.BaseHandlerSuite.SetupTest()
// Create handler and router // Create handler and router
authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret) authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret, zap.NewNop())
s.handler = NewShareHandler(s.store, s.cfg) s.handler = NewShareHandler(s.store, s.cfg)
s.router = gin.New() s.router = gin.New()

View File

@@ -1,13 +1,18 @@
package handlers package handlers
import ( import (
"context"
"encoding/base64"
"fmt"
"log" "log"
"net/http" "net/http"
"strconv"
"time" "time"
"github.com/M1ngdaXie/realtime-collab/internal/auth" "github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/config" "github.com/M1ngdaXie/realtime-collab/internal/config"
"github.com/M1ngdaXie/realtime-collab/internal/hub" "github.com/M1ngdaXie/realtime-collab/internal/hub"
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
"github.com/M1ngdaXie/realtime-collab/internal/store" "github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
@@ -19,16 +24,18 @@ import (
var connectionSem = make(chan struct{}, 200) var connectionSem = make(chan struct{}, 200)
type WebSocketHandler struct { type WebSocketHandler struct {
hub *hub.Hub hub *hub.Hub
store store.Store store store.Store
cfg *config.Config cfg *config.Config
msgBus messagebus.MessageBus
} }
func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config) *WebSocketHandler { func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config, msgBus messagebus.MessageBus) *WebSocketHandler {
return &WebSocketHandler{ return &WebSocketHandler{
hub: h, hub: h,
store: s, store: s,
cfg: cfg, cfg: cfg,
msgBus: msgBus,
} }
} }
@@ -101,7 +108,7 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
// Validate share token // Validate share token
valid, err := wsh.store.ValidateShareToken(c.Request.Context(), documentID, shareToken) valid, err := wsh.store.ValidateShareToken(c.Request.Context(), documentID, shareToken)
if err != nil { if err != nil {
log.Printf("Error validating share token: %v", err) // Error validating share token
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to validate share token"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to validate share token"})
return return
} }
@@ -127,7 +134,7 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
// Authenticated user - get their permission level // Authenticated user - get their permission level
perm, err := wsh.store.GetUserPermission(c.Request.Context(), documentID, *userID) perm, err := wsh.store.GetUserPermission(c.Request.Context(), documentID, *userID)
if err != nil { if err != nil {
log.Printf("Error getting user permission: %v", err) // Error getting user permission
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"})
return return
} }
@@ -140,7 +147,7 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
// Share token user - get share link permission // Share token user - get share link permission
perm, err := wsh.store.GetShareLinkPermission(c.Request.Context(), documentID) perm, err := wsh.store.GetShareLinkPermission(c.Request.Context(), documentID)
if err != nil { if err != nil {
log.Printf("Error getting share link permission: %v", err) // Error getting share link permission
c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to check permissions"})
return return
} }
@@ -156,7 +163,7 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
upgrader := wsh.getUpgrader() upgrader := wsh.getUpgrader()
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil { if err != nil {
log.Printf("Failed to upgrade connection: %v", err) // Failed to upgrade WebSocket connection
return return
} }
@@ -170,6 +177,105 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
// Start goroutines // Start goroutines
go client.WritePump() go client.WritePump()
go client.ReadPump() go client.ReadPump()
go wsh.replayBacklog(client, documentID)
log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID) // Client connected
}
// maxReplayUpdates caps how many backlog updates are replayed to a
// newly connected client, bounding memory and send-buffer pressure.
const maxReplayUpdates = 5000

// replayBacklog streams missed document updates to a freshly connected
// client. It reads the persisted stream checkpoint for the document and
// replays everything newer, preferring the Redis Stream and falling back
// to the Postgres update-history table when the stream read fails or
// returns nothing.
//
// Runs as a goroutine per connection; all work is bounded by a 10s
// timeout and the maxReplayUpdates cap. Errors are deliberately
// swallowed: replay is best-effort recovery.
func (wsh *WebSocketHandler) replayBacklog(client *hub.Client, documentID uuid.UUID) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// No checkpoint (or an empty stream ID) means nothing has been
	// persisted for this document yet, so there is no backlog to replay.
	checkpoint, err := wsh.store.GetStreamCheckpoint(ctx, documentID)
	if err != nil || checkpoint == nil || checkpoint.LastStreamID == "" {
		return
	}

	streamKey := "stream:" + documentID.String()
	var sent int

	// Primary: Redis stream replay
	if wsh.msgBus != nil {
		// XRange from LastStreamID is inclusive; the entry equal to the
		// checkpoint ID is skipped inside the loop.
		messages, err := wsh.msgBus.XRange(ctx, streamKey, checkpoint.LastStreamID, "+")
		if err == nil && len(messages) > 0 {
			for _, msg := range messages {
				if msg.ID == checkpoint.LastStreamID {
					continue
				}
				if sent >= maxReplayUpdates {
					log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String())
					return
				}
				// Only document updates are replayed; other entry types
				// (if any) are ignored.
				msgType := getString(msg.Values["type"])
				if msgType != "update" {
					continue
				}
				// Sequence guard: skip entries already covered by the
				// checkpoint even if their stream IDs are newer.
				seq := parseInt64(msg.Values["seq"])
				if seq <= checkpoint.LastSeq {
					continue
				}
				payloadB64 := getString(msg.Values["yjs_payload"])
				payload, err := base64.StdEncoding.DecodeString(payloadB64)
				if err != nil {
					// Undecodable entry: skip it rather than abort the
					// whole replay.
					continue
				}
				// Enqueue is non-blocking; false means the client's send
				// buffer is full, so stop replaying instead of blocking.
				if client.Enqueue(payload) {
					sent++
				} else {
					return
				}
			}
			return
		}
	}

	// Fallback: DB history replay
	updates, err := wsh.store.ListUpdateHistoryAfterSeq(ctx, documentID, checkpoint.LastSeq, maxReplayUpdates)
	if err != nil {
		return
	}
	for _, upd := range updates {
		if sent >= maxReplayUpdates {
			log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String())
			return
		}
		if client.Enqueue(upd.Payload) {
			sent++
		} else {
			return
		}
	}
}
// getString coerces a Redis stream field value to a string. Stream
// values may arrive as string or []byte depending on the client;
// anything else is rendered with fmt.Sprint.
func getString(value interface{}) string {
	if s, ok := value.(string); ok {
		return s
	}
	if b, ok := value.([]byte); ok {
		return string(b)
	}
	return fmt.Sprint(value)
}
// parseInt64 coerces a Redis stream field value to int64. Integer
// types are converted directly; string and []byte values are parsed
// as base-10 integers. Unparseable or unsupported types yield 0.
func parseInt64(value interface{}) int64 {
	switch v := value.(type) {
	case int64:
		return v
	case int:
		return int64(v)
	case uint64:
		return int64(v)
	case string:
		n, err := strconv.ParseInt(v, 10, 64)
		if err != nil {
			return 0
		}
		return n
	case []byte:
		n, err := strconv.ParseInt(string(v), 10, 64)
		if err != nil {
			return 0
		}
		return n
	default:
		return 0
	}
}

View File

@@ -2,6 +2,8 @@ package hub
import ( import (
"context" "context"
"encoding/base64"
"strconv"
"sync" "sync"
"time" "time"
@@ -37,10 +39,11 @@ type Client struct {
idsMu sync.Mutex idsMu sync.Mutex
} }
type Room struct { type Room struct {
ID string ID string
clients map[*Client]bool clients map[*Client]bool
mu sync.RWMutex mu sync.RWMutex
cancel context.CancelFunc cancel context.CancelFunc
reconnectCount int // Track Redis reconnection attempts for debugging
} }
type Hub struct { type Hub struct {
@@ -64,6 +67,10 @@ type Hub struct {
// Bounded worker pool for Redis SetAwareness // Bounded worker pool for Redis SetAwareness
awarenessQueue chan awarenessItem awarenessQueue chan awarenessItem
// Stream persistence worker pool (P1: Redis Streams durability)
streamQueue chan *Message // buffered queue for XADD operations
streamDone chan struct{} // close to signal stream workers to exit
} }
const ( const (
@@ -79,6 +86,13 @@ const (
// awarenessQueueSize is the buffer size for awareness updates. // awarenessQueueSize is the buffer size for awareness updates.
awarenessQueueSize = 4096 awarenessQueueSize = 4096
// streamWorkerCount is the number of fixed goroutines consuming from streamQueue.
// 50 workers match publish workers for consistent throughput.
streamWorkerCount = 50
// streamQueueSize is the buffer size for the stream persistence queue.
streamQueueSize = 4096
) )
type awarenessItem struct { type awarenessItem struct {
@@ -103,11 +117,15 @@ func NewHub(messagebus messagebus.MessageBus, serverID string, logger *zap.Logge
publishDone: make(chan struct{}), publishDone: make(chan struct{}),
// bounded awareness worker pool // bounded awareness worker pool
awarenessQueue: make(chan awarenessItem, awarenessQueueSize), awarenessQueue: make(chan awarenessItem, awarenessQueueSize),
// Stream persistence worker pool
streamQueue: make(chan *Message, streamQueueSize),
streamDone: make(chan struct{}),
} }
// Start the fixed worker pool for Redis publishing // Start the fixed worker pool for Redis publishing
h.startPublishWorkers(publishWorkerCount) h.startPublishWorkers(publishWorkerCount)
h.startAwarenessWorkers(awarenessWorkerCount) h.startAwarenessWorkers(awarenessWorkerCount)
h.startStreamWorkers(streamWorkerCount)
return h return h
} }
@@ -173,6 +191,82 @@ func (h *Hub) startAwarenessWorkers(n int) {
h.logger.Info("Awareness worker pool started", zap.Int("workers", n)) h.logger.Info("Awareness worker pool started", zap.Int("workers", n))
} }
// startStreamWorkers launches n goroutines that drain streamQueue and
// append each queued message to Redis Streams for durability and
// replay. Workers exit when streamDone is closed or streamQueue is
// closed and drained.
func (h *Hub) startStreamWorkers(n int) {
	for id := 0; id < n; id++ {
		worker := id
		go func() {
			for {
				select {
				case <-h.streamDone:
					h.logger.Info("Stream worker exiting", zap.Int("worker_id", worker))
					return
				case msg, ok := <-h.streamQueue:
					if !ok {
						return
					}
					h.addToStream(msg)
				}
			}
		}()
	}
	h.logger.Info("Stream worker pool started", zap.Int("workers", n))
}
// encodeBase64 renders raw bytes as a standard base64 string so the
// binary Yjs payload can be stored safely in a Redis stream field.
func encodeBase64(data []byte) string {
	return base64.StdEncoding.EncodeToString(data)
}
// addToStream persists one broadcast message into the per-room Redis
// Stream so late joiners can replay it. Called only from the stream
// worker pool; each call uses its own short timeout so a slow Redis
// cannot wedge a worker.
func (h *Hub) addToStream(msg *Message) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	streamKey := "stream:" + msg.RoomID

	// Get next sequence number atomically. The monotonically increasing
	// seq lets consumers deduplicate against stream checkpoints.
	seqKey := "seq:" + msg.RoomID
	seq, err := h.messagebus.Incr(ctx, seqKey)
	if err != nil {
		h.logger.Error("Failed to increment sequence",
			zap.String("room_id", msg.RoomID),
			zap.Error(err))
		return
	}

	// Encode payload as base64 (binary-safe storage in stream fields).
	payload := encodeBase64(msg.Data)

	// Extract Yjs message type from first byte as numeric string.
	msgType := "0"
	if len(msg.Data) > 0 {
		msgType = strconv.Itoa(int(msg.Data[0]))
	}

	// Add entry to the Stream with approximate MAXLEN (~10000) trimming
	// so the stream stays bounded without exact-trim overhead.
	values := map[string]interface{}{
		"type":        "update",
		"server_id":   h.serverID,
		"yjs_payload": payload,
		"msg_type":    msgType,
		"seq":         seq,
		"timestamp":   time.Now().Format(time.RFC3339),
	}
	_, err = h.messagebus.XAdd(ctx, streamKey, 10000, true, values)
	if err != nil {
		h.logger.Error("Failed to add to Stream",
			zap.String("stream_key", streamKey),
			zap.Int64("seq", seq),
			zap.Error(err))
		return
	}

	// Mark this document as active (score = unix time) so the persist
	// worker only processes streams with recent activity. Best-effort:
	// a failed ZAdd is ignored rather than failing the write.
	_ = h.messagebus.ZAdd(ctx, "active-streams", float64(time.Now().Unix()), msg.RoomID)
}
func (h *Hub) Run() { func (h *Hub) Run() {
for { for {
select { select {
@@ -471,6 +565,7 @@ func (h *Hub) broadcastMessage(message *Message) {
// 只有本地客户端发出的消息 (sender != nil) 才推送到 Redis // 只有本地客户端发出的消息 (sender != nil) 才推送到 Redis
// P0 fix: send to bounded worker pool instead of spawning unbounded goroutines // P0 fix: send to bounded worker pool instead of spawning unbounded goroutines
if message.sender != nil && !h.fallbackMode && h.messagebus != nil { if message.sender != nil && !h.fallbackMode && h.messagebus != nil {
// 3a. Publish to Pub/Sub (real-time cross-server broadcast)
select { select {
case h.publishQueue <- message: case h.publishQueue <- message:
// Successfully queued for async publish by worker pool // Successfully queued for async publish by worker pool
@@ -479,6 +574,19 @@ func (h *Hub) broadcastMessage(message *Message) {
h.logger.Warn("Publish queue full, dropping Redis publish", h.logger.Warn("Publish queue full, dropping Redis publish",
zap.String("room_id", message.RoomID)) zap.String("room_id", message.RoomID))
} }
// 3b. Add to Stream for durability (only Type 0 updates, not Type 1 awareness)
// Type 0 = Yjs sync/update messages (document changes)
// Type 1 = Yjs awareness messages (cursors, presence) - ephemeral, skip
if len(message.Data) > 0 && message.Data[0] == 0 {
select {
case h.streamQueue <- message:
// Successfully queued for async Stream add
default:
h.logger.Warn("Stream queue full, dropping durability",
zap.String("room_id", message.RoomID))
}
}
} }
} }
@@ -504,10 +612,28 @@ func (h *Hub) broadcastToLocalClients(room *Room, data []byte, sender *Client) {
} }
} }
func (h *Hub) startRoomMessageForwarding(ctx context.Context, roomID string, msgChan <-chan []byte) { func (h *Hub) startRoomMessageForwarding(ctx context.Context, roomID string, msgChan <-chan []byte) {
h.logger.Info("Starting message forwarding from Redis to room", // Increment and log reconnection count for debugging
zap.String("room_id", roomID), h.mu.RLock()
zap.String("server_id", h.serverID), room, exists := h.rooms[roomID]
) h.mu.RUnlock()
if exists {
room.mu.Lock()
room.reconnectCount++
reconnectCount := room.reconnectCount
room.mu.Unlock()
h.logger.Info("Starting message forwarding from Redis to room",
zap.String("room_id", roomID),
zap.String("server_id", h.serverID),
zap.Int("reconnect_count", reconnectCount),
)
} else {
h.logger.Info("Starting message forwarding from Redis to room",
zap.String("room_id", roomID),
zap.String("server_id", h.serverID),
)
}
for { for {
select { select {
@@ -791,12 +917,28 @@ func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string
UserAvatar: userAvatar, UserAvatar: userAvatar,
Permission: permission, Permission: permission,
Conn: conn, Conn: conn,
send: make(chan []byte, 1024), send: make(chan []byte, 8192),
hub: hub, hub: hub,
roomID: roomID, roomID: roomID,
observedYjsIDs: make(map[uint64]uint64), observedYjsIDs: make(map[uint64]uint64),
} }
} }
// Enqueue attempts a non-blocking send of message into the client's
// outbound buffer. It reports false (and logs a warning) when the
// buffer is full, letting callers abort a replay instead of blocking.
func (c *Client) Enqueue(message []byte) bool {
	select {
	case c.send <- message:
	default:
		if c.hub != nil && c.hub.logger != nil {
			c.hub.logger.Warn("Client send buffer full during replay",
				zap.String("client_id", c.ID),
				zap.String("room_id", c.roomID))
		}
		return false
	}
	return true
}
func (c *Client) unregister() { func (c *Client) unregister() {
c.unregisterOnce.Do(func() { c.unregisterOnce.Do(func() {
c.hub.Unregister <- c c.hub.Unregister <- c

View File

@@ -23,7 +23,7 @@ func NewLogger(isDevelopment bool) (*zap.Logger, error) {
// 👇 关键修改:直接拉到 Fatal 级别 // 👇 关键修改:直接拉到 Fatal 级别
// 这样 Error, Warn, Info, Debug 全部都会被忽略 // 这样 Error, Warn, Info, Debug 全部都会被忽略
// 彻底消除 IO 锁竞争 // 彻底消除 IO 锁竞争
config.Level = zap.NewAtomicLevelAt(zapcore.FatalLevel) config.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel)
logger, err := config.Build() logger, err := config.Build()
if err != nil { if err != nil {

View File

@@ -2,6 +2,7 @@ package messagebus
import ( import (
"context" "context"
"time"
) )
// MessageBus abstracts message distribution across server instances // MessageBus abstracts message distribution across server instances
@@ -33,6 +34,72 @@ type MessageBus interface {
// Close gracefully shuts down the message bus // Close gracefully shuts down the message bus
Close() error Close() error
// ========== Redis Streams Operations ==========
// XAdd adds a new entry to a stream with optional MAXLEN trimming
XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error)
// XReadGroup reads messages from a stream using a consumer group
XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error)
// XAck acknowledges one or more messages from a consumer group
XAck(ctx context.Context, stream, group string, ids ...string) (int64, error)
// XGroupCreate creates a new consumer group for a stream
XGroupCreate(ctx context.Context, stream, group, start string) error
// XGroupCreateMkStream creates a consumer group and the stream if it doesn't exist
XGroupCreateMkStream(ctx context.Context, stream, group, start string) error
// XPending returns pending messages information for a consumer group
XPending(ctx context.Context, stream, group string) (*PendingInfo, error)
// XClaim claims pending messages from a consumer group
XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error)
// XAutoClaim claims pending messages automatically (Redis >= 6.2)
// Returns claimed messages and next start ID.
XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error)
// XRange reads a range of messages from a stream
XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error)
// XTrimMinID trims a stream to a minimum ID (time-based retention)
XTrimMinID(ctx context.Context, stream, minID string) (int64, error)
// Incr increments a counter atomically (for sequence numbers)
Incr(ctx context.Context, key string) (int64, error)
// ========== Sorted Set (ZSET) Operations ==========
// ZAdd adds a member with a score to a sorted set (used for active-stream tracking)
ZAdd(ctx context.Context, key string, score float64, member string) error
// ZRangeByScore returns members with scores between min and max
ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error)
// ZRemRangeByScore removes members with scores between min and max
ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error)
// Distributed lock helpers (used by background workers)
AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error)
RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error)
ReleaseLock(ctx context.Context, key string) error
}
// StreamMessage represents a single entry read from a Redis Stream,
// decoupled from any particular Redis client library.
type StreamMessage struct {
	ID     string                 // Redis stream entry ID (e.g. "1680000000000-0")
	Values map[string]interface{} // field/value pairs stored with the entry
}

// PendingInfo summarizes a consumer group's pending (delivered but
// not yet acknowledged) entries, as reported by XPENDING.
type PendingInfo struct {
	Count     int64            // total number of pending entries
	Lower     string           // smallest pending entry ID
	Upper     string           // largest pending entry ID
	Consumers map[string]int64 // pending entry count per consumer name
}
// LocalMessageBus is a no-op implementation for single-server mode // LocalMessageBus is a no-op implementation for single-server mode
@@ -78,3 +145,73 @@ func (l *LocalMessageBus) IsHealthy() bool {
func (l *LocalMessageBus) Close() error { func (l *LocalMessageBus) Close() error {
return nil return nil
} }
// ========== Redis Streams Operations (No-op for local mode) ==========
//
// In single-server mode there is no cross-server durability layer, so
// every stream, sorted-set, and lock operation below is a harmless
// stub. Lock acquisition always "succeeds" so background workers run
// as if they held the lock.

// XAdd pretends the entry was written; "0-0" is the smallest valid stream ID.
func (l *LocalMessageBus) XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) {
	return "0-0", nil
}

// XReadGroup never returns messages in local mode.
func (l *LocalMessageBus) XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) {
	return nil, nil
}

// XAck acknowledges nothing; there are no pending entries locally.
func (l *LocalMessageBus) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) {
	return 0, nil
}

// XGroupCreate is a no-op.
func (l *LocalMessageBus) XGroupCreate(ctx context.Context, stream, group, start string) error {
	return nil
}

// XGroupCreateMkStream is a no-op.
func (l *LocalMessageBus) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error {
	return nil
}

// XPending reports an empty (zero-valued) pending summary.
func (l *LocalMessageBus) XPending(ctx context.Context, stream, group string) (*PendingInfo, error) {
	return &PendingInfo{}, nil
}

// XClaim claims nothing.
func (l *LocalMessageBus) XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) {
	return nil, nil
}

// XAutoClaim claims nothing; next-start cursor stays at "0-0".
func (l *LocalMessageBus) XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) {
	return nil, "0-0", nil
}

// XRange returns no entries.
func (l *LocalMessageBus) XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) {
	return nil, nil
}

// XTrimMinID trims nothing.
func (l *LocalMessageBus) XTrimMinID(ctx context.Context, stream, minID string) (int64, error) {
	return 0, nil
}

// Incr always returns 0 in local mode.
// NOTE(review): this means sequence numbers never advance locally;
// presumably acceptable because stream persistence/replay is also
// disabled in this mode — confirm callers don't rely on monotonic seq.
func (l *LocalMessageBus) Incr(ctx context.Context, key string) (int64, error) {
	return 0, nil
}

// ZAdd is a no-op.
func (l *LocalMessageBus) ZAdd(ctx context.Context, key string, score float64, member string) error {
	return nil
}

// ZRangeByScore returns no members.
func (l *LocalMessageBus) ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) {
	return nil, nil
}

// ZRemRangeByScore removes nothing.
func (l *LocalMessageBus) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) {
	return 0, nil
}

// AcquireLock always succeeds: with a single server there is no one to
// contend with.
func (l *LocalMessageBus) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
	return true, nil
}

// RefreshLock always succeeds.
func (l *LocalMessageBus) RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
	return true, nil
}

// ReleaseLock is a no-op.
func (l *LocalMessageBus) ReleaseLock(ctx context.Context, key string) error {
	return nil
}

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"net"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@@ -88,6 +89,23 @@ func NewRedisMessageBus(redisURL string, serverID string, logger *zap.Logger) (*
// - Redis will handle stale connections via TCP keepalive // - Redis will handle stale connections via TCP keepalive
opts.ConnMaxLifetime = 1 * time.Hour opts.ConnMaxLifetime = 1 * time.Hour
// ================================
// Socket-Level Timeout Configuration (prevents indefinite hangs)
// ================================
// Without these, TCP reads/writes block indefinitely when Redis is unresponsive,
// causing OS-level timeouts (60-120s) instead of application-level control.
// DialTimeout: How long to wait for initial connection establishment
opts.DialTimeout = 5 * time.Second
// ReadTimeout: Maximum time for socket read operations
// - 30s is appropriate for PubSub (long intervals between messages are normal)
// - Prevents indefinite blocking when Redis hangs
opts.ReadTimeout = 30 * time.Second
// WriteTimeout: Maximum time for socket write operations
opts.WriteTimeout = 5 * time.Second
client := goredis.NewClient(opts) client := goredis.NewClient(opts)
// ================================ // ================================
@@ -215,12 +233,15 @@ func (r *RedisMessageBus) readLoop(ctx context.Context, roomID string, sub *subs
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
r.logger.Warn("PubSub initial subscription failed, retrying with backoff",
zap.String("roomID", roomID),
zap.Error(err),
zap.Duration("backoff", backoff),
)
time.Sleep(backoff) time.Sleep(backoff)
if backoff < maxBackoff { backoff = backoff * 2
backoff *= 2 if backoff > maxBackoff {
if backoff > maxBackoff { backoff = maxBackoff
backoff = maxBackoff
}
} }
continue continue
} }
@@ -242,12 +263,15 @@ func (r *RedisMessageBus) readLoop(ctx context.Context, roomID string, sub *subs
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
r.logger.Warn("PubSub receive failed, retrying with backoff",
zap.String("roomID", roomID),
zap.Error(err),
zap.Duration("backoff", backoff),
)
time.Sleep(backoff) time.Sleep(backoff)
if backoff < maxBackoff { backoff = backoff * 2
backoff *= 2 if backoff > maxBackoff {
if backoff > maxBackoff { backoff = maxBackoff
backoff = maxBackoff
}
} }
} }
} }
@@ -261,12 +285,15 @@ func (r *RedisMessageBus) receiveOnce(ctx context.Context, roomID string, pubsub
msg, err := pubsub.ReceiveTimeout(ctx, 5*time.Second) msg, err := pubsub.ReceiveTimeout(ctx, 5*time.Second)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { if ctx.Err() != nil {
return err return ctx.Err()
} }
if errors.Is(err, goredis.Nil) { if errors.Is(err, goredis.Nil) {
continue continue
} }
if isTimeoutErr(err) {
continue
}
r.logger.Warn("pubsub receive error, closing subscription", r.logger.Warn("pubsub receive error, closing subscription",
zap.String("roomID", roomID), zap.String("roomID", roomID),
zap.Error(err), zap.Error(err),
@@ -308,6 +335,17 @@ func (r *RedisMessageBus) receiveOnce(ctx context.Context, roomID string, pubsub
} }
} }
func isTimeoutErr(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}
// Unsubscribe stops listening to a room // Unsubscribe stops listening to a room
func (r *RedisMessageBus) Unsubscribe(ctx context.Context, roomID string) error { func (r *RedisMessageBus) Unsubscribe(ctx context.Context, roomID string) error {
r.subMu.Lock() r.subMu.Lock()
@@ -430,7 +468,7 @@ func (r *RedisMessageBus) DeleteAwareness(ctx context.Context, roomID string, cl
// IsHealthy checks Redis connectivity // IsHealthy checks Redis connectivity
func (r *RedisMessageBus) IsHealthy() bool { func (r *RedisMessageBus) IsHealthy() bool {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
// 只有 Ping 成功且没有报错,才认为服务是健康的 // 只有 Ping 成功且没有报错,才认为服务是健康的
@@ -516,3 +554,223 @@ func (r *RedisMessageBus) ClearAllAwareness(ctx context.Context, roomID string)
// 直接使用 Del 命令删除整个 Key // 直接使用 Del 命令删除整个 Key
return r.client.Del(ctx, key).Err() return r.client.Del(ctx, key).Err()
} }
// ========== Redis Streams Operations ==========

// XAdd appends an entry to stream, trimming the stream to roughly
// maxLen entries when approx is true (MAXLEN ~). Returns the new
// entry's stream ID.
func (r *RedisMessageBus) XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) {
	args := &goredis.XAddArgs{
		Stream: stream,
		MaxLen: maxLen,
		Approx: approx,
		Values: values,
	}
	return r.client.XAdd(ctx, args).Result()
}
// XReadGroup reads up to count messages for consumer in group from the
// given streams, blocking for at most block. A (nil, nil) return means
// the block timed out with no new messages.
func (r *RedisMessageBus) XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) {
	result := r.client.XReadGroup(ctx, &goredis.XReadGroupArgs{
		Group:    group,
		Consumer: consumer,
		Streams:  streams,
		Count:    count,
		Block:    block,
	})
	if err := result.Err(); err != nil {
		// goredis.Nil signals a block timeout, not a failure.
		// errors.Is (rather than ==) also recognizes wrapped sentinels,
		// matching how this file handles goredis.Nil elsewhere.
		if errors.Is(err, goredis.Nil) {
			return nil, nil
		}
		return nil, err
	}
	// Flatten the per-stream go-redis result into our transport-agnostic
	// StreamMessage format.
	var messages []StreamMessage
	for _, stream := range result.Val() {
		for _, msg := range stream.Messages {
			messages = append(messages, StreamMessage{
				ID:     msg.ID,
				Values: msg.Values,
			})
		}
	}
	return messages, nil
}
// XAck acknowledges one or more messages from a consumer group,
// removing them from the group's pending entries list. Returns the
// number of entries actually acknowledged.
func (r *RedisMessageBus) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) {
	result := r.client.XAck(ctx, stream, group, ids...)
	return result.Val(), result.Err()
}

// XGroupCreate creates a new consumer group on an existing stream,
// starting delivery at the given stream ID (e.g. "0" or "$").
func (r *RedisMessageBus) XGroupCreate(ctx context.Context, stream, group, start string) error {
	return r.client.XGroupCreate(ctx, stream, group, start).Err()
}

// XGroupCreateMkStream creates a consumer group, creating the stream
// first (MKSTREAM) if it does not already exist.
func (r *RedisMessageBus) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error {
	return r.client.XGroupCreateMkStream(ctx, stream, group, start).Err()
}
// XPending returns summary information about a consumer group's pending
// (delivered but not yet acknowledged) messages.
func (r *RedisMessageBus) XPending(ctx context.Context, stream, group string) (*PendingInfo, error) {
	result := r.client.XPending(ctx, stream, group)
	if err := result.Err(); err != nil {
		return nil, err
	}
	pending := result.Val()
	// Copy per-consumer counts into a plain map so callers are not
	// coupled to the go-redis result type.
	consumers := make(map[string]int64)
	for name, count := range pending.Consumers {
		consumers[name] = count
	}
	return &PendingInfo{
		Count:     pending.Count,
		Lower:     pending.Lower,
		Upper:     pending.Higher, // go-redis uses "Higher" instead of "Upper"
		Consumers: consumers,
	}, nil
}
// XClaim transfers ownership of the given pending entry IDs to consumer,
// provided each entry has been idle for at least minIdleTime. Used to
// recover messages from crashed consumers.
func (r *RedisMessageBus) XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) {
	result := r.client.XClaim(ctx, &goredis.XClaimArgs{
		Stream:   stream,
		Group:    group,
		Consumer: consumer,
		MinIdle:  minIdleTime,
		Messages: ids,
	})
	if err := result.Err(); err != nil {
		return nil, err
	}
	// Convert go-redis XMessage to our StreamMessage format.
	var messages []StreamMessage
	for _, msg := range result.Val() {
		messages = append(messages, StreamMessage{
			ID:     msg.ID,
			Values: msg.Values,
		})
	}
	return messages, nil
}
// XAutoClaim scans the group's pending entries starting at start and
// claims up to count entries idle for at least minIdleTime (Redis >= 6.2).
// Returns the claimed messages and the cursor to pass as start on the
// next call.
func (r *RedisMessageBus) XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) {
	result := r.client.XAutoClaim(ctx, &goredis.XAutoClaimArgs{
		Stream:  stream,
		Group:   group,
		Consumer: consumer,
		MinIdle: minIdleTime,
		Start:   start,
		Count:   count,
	})
	msgs, nextStart, err := result.Result()
	if err != nil {
		return nil, "", err
	}
	// Convert go-redis XMessage to our StreamMessage format.
	messages := make([]StreamMessage, 0, len(msgs))
	for _, msg := range msgs {
		messages = append(messages, StreamMessage{
			ID:     msg.ID,
			Values: msg.Values,
		})
	}
	return messages, nextStart, nil
}
// XRange returns the stream entries between start and end (inclusive;
// "-" and "+" select the stream extremes), converted into the
// message-bus StreamMessage format.
func (r *RedisMessageBus) XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) {
	raw, err := r.client.XRange(ctx, stream, start, end).Result()
	if err != nil {
		return nil, err
	}
	var messages []StreamMessage
	for _, entry := range raw {
		messages = append(messages, StreamMessage{ID: entry.ID, Values: entry.Values})
	}
	return messages, nil
}
// XTrimMinID trims a stream to a minimum ID (time-based retention),
// removing entries with IDs lower than minID. Returns the number of
// entries removed.
func (r *RedisMessageBus) XTrimMinID(ctx context.Context, stream, minID string) (int64, error) {
	// Use XTRIM with MINID and approximation (~) for efficiency.
	// The LIMIT clause caps work per call so large trims don't block Redis;
	// remaining entries are trimmed on subsequent calls.
	result := r.client.Do(ctx, "XTRIM", stream, "MINID", "~", minID, "LIMIT", 1000)
	if err := result.Err(); err != nil {
		return 0, err
	}
	// Result is the number of entries removed.
	trimmed, err := result.Int64()
	if err != nil {
		return 0, err
	}
	return trimmed, nil
}
// ========== Sorted Set (ZSET) Operations ==========

// ZAdd adds a member with a score to a sorted set. Used for the
// "active-streams" set, where the score is a unix timestamp of the
// last activity.
func (r *RedisMessageBus) ZAdd(ctx context.Context, key string, score float64, member string) error {
	return r.client.ZAdd(ctx, key, goredis.Z{Score: score, Member: member}).Err()
}

// ZRangeByScore returns members whose scores lie between min and max
// (inclusive).
func (r *RedisMessageBus) ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) {
	return r.client.ZRangeByScore(ctx, key, &goredis.ZRangeBy{
		Min: strconv.FormatFloat(min, 'f', -1, 64),
		Max: strconv.FormatFloat(max, 'f', -1, 64),
	}).Result()
}

// ZRemRangeByScore removes members whose scores lie between min and max
// (inclusive), returning the number removed.
func (r *RedisMessageBus) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) {
	return r.client.ZRemRangeByScore(ctx, key,
		strconv.FormatFloat(min, 'f', -1, 64),
		strconv.FormatFloat(max, 'f', -1, 64),
	).Result()
}

// Incr atomically increments the counter stored at key, returning the
// new value. Used to generate per-room stream sequence numbers.
func (r *RedisMessageBus) Incr(ctx context.Context, key string) (int64, error) {
	result := r.client.Incr(ctx, key)
	return result.Val(), result.Err()
}
// AcquireLock attempts to take a distributed lock by writing this
// server's ID into key with SET NX and a TTL. Returns true when the
// lock was won; false when another holder already owns it.
func (r *RedisMessageBus) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
	return r.client.SetNX(ctx, key, r.serverID, ttl).Result()
}

// refreshLockScript extends the TTL only while the lock value still
// matches the caller's server ID, so one server cannot extend a lock
// that has expired and been re-acquired by another server.
var refreshLockScript = goredis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
	return redis.call("PEXPIRE", KEYS[1], ARGV[2])
end
return 0
`)

// RefreshLock extends the TTL on a lock this server owns. Returns
// false (without error) when the lock has expired or is now held by a
// different server.
func (r *RedisMessageBus) RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
	res, err := refreshLockScript.Run(ctx, r.client, []string{key}, r.serverID, ttl.Milliseconds()).Int64()
	if err != nil {
		return false, err
	}
	return res == 1, nil
}

// releaseLockScript deletes the lock only while it is still owned by
// the caller. An unconditional DEL here would be a bug: after this
// server's TTL lapses another server may hold the lock, and deleting
// it would break mutual exclusion.
var releaseLockScript = goredis.NewScript(`
if redis.call("GET", KEYS[1]) == ARGV[1] then
	return redis.call("DEL", KEYS[1])
end
return 0
`)

// ReleaseLock releases a distributed lock held by this server.
// Releasing a lock we no longer own is a silent no-op.
func (r *RedisMessageBus) ReleaseLock(ctx context.Context, key string) error {
	return releaseLockScript.Run(ctx, r.client, []string{key}, r.serverID).Err()
}

View File

@@ -0,0 +1,15 @@
package models
import (
"time"
"github.com/google/uuid"
)
// StreamCheckpoint tracks the last processed Redis Stream entry per
// document. Reconnecting clients replay every stream entry after
// (LastStreamID, LastSeq) to catch up on missed updates.
type StreamCheckpoint struct {
	DocumentID   uuid.UUID `json:"document_id"`    // document this checkpoint belongs to
	LastStreamID string    `json:"last_stream_id"` // Redis stream ID of the last processed entry
	LastSeq      int64     `json:"last_seq"`       // last processed sequence number (from the per-room counter)
	UpdatedAt    time.Time `json:"updated_at"`     // when the checkpoint was last advanced
}

View File

@@ -53,6 +53,15 @@ type Store interface {
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error) GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error) GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
// Stream checkpoint operations
UpsertStreamCheckpoint(ctx context.Context, documentID uuid.UUID, streamID string, seq int64) error
GetStreamCheckpoint(ctx context.Context, documentID uuid.UUID) (*models.StreamCheckpoint, error)
// Update history (WAL) operations
InsertUpdateHistoryBatch(ctx context.Context, entries []UpdateHistoryEntry) error
ListUpdateHistoryAfterSeq(ctx context.Context, documentID uuid.UUID, afterSeq int64, limit int) ([]UpdateHistoryEntry, error)
DeleteUpdateHistoryUpToSeq(ctx context.Context, documentID uuid.UUID, maxSeq int64) error
Close() error Close() error
} }

View File

@@ -0,0 +1,46 @@
package store
import (
"context"
"fmt"
"github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/google/uuid"
)
// UpsertStreamCheckpoint creates or updates the stream checkpoint for a
// document. The row is keyed by document_id, so repeated calls simply
// advance last_stream_id/last_seq and bump updated_at.
func (s *PostgresStore) UpsertStreamCheckpoint(ctx context.Context, documentID uuid.UUID, streamID string, seq int64) error {
	const upsertSQL = `
		INSERT INTO stream_checkpoints (document_id, last_stream_id, last_seq, updated_at)
		VALUES ($1, $2, $3, NOW())
		ON CONFLICT (document_id)
		DO UPDATE SET last_stream_id = EXCLUDED.last_stream_id,
			last_seq = EXCLUDED.last_seq,
			updated_at = NOW()
	`
	_, err := s.db.ExecContext(ctx, upsertSQL, documentID, streamID, seq)
	if err != nil {
		return fmt.Errorf("failed to upsert stream checkpoint: %w", err)
	}
	return nil
}
// GetStreamCheckpoint retrieves the stream checkpoint for a document.
// A missing row surfaces as a wrapped sql.ErrNoRows (reachable via
// errors.Is on the returned error).
func (s *PostgresStore) GetStreamCheckpoint(ctx context.Context, documentID uuid.UUID) (*models.StreamCheckpoint, error) {
	const selectSQL = `
		SELECT document_id, last_stream_id, last_seq, updated_at
		FROM stream_checkpoints
		WHERE document_id = $1
	`
	cp := &models.StreamCheckpoint{}
	err := s.db.QueryRowContext(ctx, selectSQL, documentID).Scan(
		&cp.DocumentID,
		&cp.LastStreamID,
		&cp.LastSeq,
		&cp.UpdatedAt,
	)
	if err != nil {
		return nil, fmt.Errorf("failed to get stream checkpoint: %w", err)
	}
	return cp, nil
}

View File

@@ -71,10 +71,14 @@ func SetupTestDB(t *testing.T) (*PostgresStore, func()) {
// Run migrations // Run migrations
scriptsDir := filepath.Join("..", "..", "scripts") scriptsDir := filepath.Join("..", "..", "scripts")
migrations := []string{ migrations := []string{
"init.sql", "000_extensions.sql",
"001_add_users_and_sessions.sql", "001_init_schema.sql",
"002_add_document_shares.sql", "002_add_users_and_sessions.sql",
"003_add_public_sharing.sql", "003_add_document_shares.sql",
"004_add_public_sharing.sql",
"005_add_share_link_permission.sql",
"010_add_stream_checkpoints.sql",
"011_add_update_history.sql",
} }
for _, migration := range migrations { for _, migration := range migrations {
@@ -107,6 +111,8 @@ func SetupTestDB(t *testing.T) (*PostgresStore, func()) {
func TruncateAllTables(ctx context.Context, store *PostgresStore) error { func TruncateAllTables(ctx context.Context, store *PostgresStore) error {
tables := []string{ tables := []string{
"document_updates", "document_updates",
"document_update_history",
"stream_checkpoints",
"document_shares", "document_shares",
"sessions", "sessions",
"documents", "documents",

View File

@@ -0,0 +1,115 @@
package store
import (
"context"
"fmt"
"strings"
"time"
"unicode/utf8"
"github.com/google/uuid"
)
// UpdateHistoryEntry represents a persisted update from Redis Streams
// used for recovery and replay.
//
// StreamID is the Redis Stream entry ID and, together with DocumentID,
// uniquely identifies the entry; Seq is the application-level sequence
// number used for ordering and pruning. Payload is the raw (already
// base64-decoded) Yjs update bytes.
type UpdateHistoryEntry struct {
DocumentID uuid.UUID
StreamID string
Seq int64
Payload []byte
MsgType string // optional message type; stored as NULL when empty
ServerID string // optional originating server; stored as NULL when empty
CreatedAt time.Time
}
// InsertUpdateHistoryBatch inserts update history entries in batches.
//
// Inserts are idempotent: ON CONFLICT DO NOTHING (without a conflict
// target) skips rows violating ANY unique index on the table — both
// (document_id, stream_id) and (document_id, seq) — so replaying a
// stream after a crash cannot abort the batch. The previous form only
// targeted (document_id, stream_id), so a duplicate seq with a different
// stream ID would have failed the whole statement.
//
// Entries are chunked so a single statement never approaches PostgreSQL's
// 65535 bind-parameter limit (7 parameters per row).
func (s *PostgresStore) InsertUpdateHistoryBatch(ctx context.Context, entries []UpdateHistoryEntry) error {
	const cols = 7
	// 5000 rows * 7 params = 35000, comfortably under the 65535 limit.
	const maxRowsPerStmt = 5000

	for start := 0; start < len(entries); start += maxRowsPerStmt {
		end := start + maxRowsPerStmt
		if end > len(entries) {
			end = len(entries)
		}
		chunk := entries[start:end]

		var sb strings.Builder
		sb.WriteString("INSERT INTO document_update_history (document_id, stream_id, seq, payload, msg_type, server_id, created_at) VALUES ")
		args := make([]interface{}, 0, len(chunk)*cols)
		for i, e := range chunk {
			if i > 0 {
				sb.WriteString(",")
			}
			base := i*cols + 1
			sb.WriteString(fmt.Sprintf("($%d,$%d,$%d,$%d,$%d,$%d,$%d)", base, base+1, base+2, base+3, base+4, base+5, base+6))
			// Defensive cleanup: Postgres TEXT cannot hold NUL bytes or
			// invalid UTF-8; empty strings are stored as NULL.
			msgType := sanitizeTextForDB(e.MsgType)
			serverID := sanitizeTextForDB(e.ServerID)
			args = append(args, e.DocumentID, e.StreamID, e.Seq, e.Payload, nullIfEmpty(msgType), nullIfEmpty(serverID), e.CreatedAt)
		}
		// Idempotent insert covering every unique index on the table.
		sb.WriteString(" ON CONFLICT DO NOTHING")
		if _, err := s.db.ExecContext(ctx, sb.String(), args...); err != nil {
			return fmt.Errorf("failed to insert update history batch: %w", err)
		}
	}
	return nil
}
// ListUpdateHistoryAfterSeq returns updates with seq greater than afterSeq,
// ordered by seq ascending.
//
// A non-positive limit defaults to 1000. NULL msg_type/server_id columns
// are returned as empty strings (COALESCE in the query).
func (s *PostgresStore) ListUpdateHistoryAfterSeq(ctx context.Context, documentID uuid.UUID, afterSeq int64, limit int) ([]UpdateHistoryEntry, error) {
	if limit <= 0 {
		limit = 1000
	}
	query := `
		SELECT document_id, stream_id, seq, payload, COALESCE(msg_type, ''), COALESCE(server_id, ''), created_at
		FROM document_update_history
		WHERE document_id = $1 AND seq > $2
		ORDER BY seq ASC
		LIMIT $3
	`
	rows, err := s.db.QueryContext(ctx, query, documentID, afterSeq, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to list update history: %w", err)
	}
	defer rows.Close()
	var results []UpdateHistoryEntry
	for rows.Next() {
		var e UpdateHistoryEntry
		if err := rows.Scan(&e.DocumentID, &e.StreamID, &e.Seq, &e.Payload, &e.MsgType, &e.ServerID, &e.CreatedAt); err != nil {
			return nil, fmt.Errorf("failed to scan update history: %w", err)
		}
		results = append(results, e)
	}
	// rows.Next returning false can mask an iteration error (e.g. the
	// connection dropping mid-result-set); surface it instead of silently
	// returning a truncated result, which would corrupt a replay.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate update history: %w", err)
	}
	return results, nil
}
// DeleteUpdateHistoryUpToSeq deletes updates with seq <= maxSeq for a document.
// Used for retention pruning once a checkpoint makes older entries redundant.
func (s *PostgresStore) DeleteUpdateHistoryUpToSeq(ctx context.Context, documentID uuid.UUID, maxSeq int64) error {
	const deleteSQL = `
		DELETE FROM document_update_history
		WHERE document_id = $1 AND seq <= $2
	`
	_, err := s.db.ExecContext(ctx, deleteSQL, documentID, maxSeq)
	if err != nil {
		return fmt.Errorf("failed to delete update history: %w", err)
	}
	return nil
}
// nullIfEmpty maps the empty string to SQL NULL (nil) and passes any other
// value through unchanged, for use as a bind parameter.
func nullIfEmpty(s string) interface{} {
	if len(s) > 0 {
		return s
	}
	return nil
}
// sanitizeTextForDB drops values that PostgreSQL TEXT columns cannot store:
// strings containing NUL bytes or invalid UTF-8 come back as "".
func sanitizeTextForDB(s string) string {
	switch {
	case s == "":
		return ""
	case strings.IndexByte(s, 0) != -1:
		return ""
	case !utf8.ValidString(s):
		return ""
	default:
		return s
	}
}

View File

@@ -0,0 +1,322 @@
package workers
import (
"context"
"encoding/base64"
"fmt"
"runtime/debug"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
"github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/google/uuid"
"go.uber.org/zap"
)
const (
	// updatePersistGroupName is the Redis Streams consumer group this worker reads with.
	updatePersistGroupName = "update-persist-worker"
	// updatePersistLockKey is the distributed-lock key that guarantees a single
	// active persist worker across all server instances.
	updatePersistLockKey = "lock:update-persist-worker"
	updatePersistLockTTL = 30 * time.Second
	// updatePersistTick is both the processing interval and the retry backoff.
	updatePersistTick = 2 * time.Second
	// updateReadCount caps entries fetched per XREADGROUP / XAUTOCLAIM call.
	updateReadCount = 200
	updateReadBlock = -1 // negative → go-redis omits BLOCK clause → non-blocking
	updateBatchSize = 500
	// updateSafeSeqLag is how many sequence numbers of history to keep behind a
	// snapshot checkpoint before pruning rows from Postgres.
	updateSafeSeqLag = int64(1000)
	// updateAutoClaimIdle is how long a pending entry must sit unacked before
	// this worker claims it from a dead consumer.
	updateAutoClaimIdle = 30 * time.Second
	updateHeartbeatEvery = 30 * time.Second
)
// StartUpdatePersistWorker persists Redis Stream updates into Postgres for recovery.
// It loops until ctx is cancelled: each outer iteration tries to become the
// single active worker via a distributed lock, runs the persist loop while the
// lock is held, and recovers from panics so one bad iteration cannot kill the
// process. Callers are expected to run this in its own goroutine.
func StartUpdatePersistWorker(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) {
	// Both dependencies are required; silently no-op otherwise.
	if msgBus == nil || dbStore == nil {
		return
	}
	for {
		// The body runs in a closure so the deferred recover() scopes to one
		// iteration. Note the `return`s inside exit only the closure; the outer
		// loop then re-checks ctx and sleeps before retrying.
		func() {
			defer func() {
				if r := recover(); r != nil {
					logWorker(logger, "Update persist worker panic",
						zap.Any("panic", r),
						zap.ByteString("stack", debug.Stack()))
				}
			}()
			select {
			case <-ctx.Done():
				return
			default:
			}
			acquired, err := msgBus.AcquireLock(ctx, updatePersistLockKey, updatePersistLockTTL)
			if err != nil {
				logWorker(logger, "Failed to acquire update persist worker lock", zap.Error(err))
				time.Sleep(updatePersistTick)
				return
			}
			if !acquired {
				// Another instance holds the lock; back off and re-check later.
				time.Sleep(updatePersistTick)
				return
			}
			logWorker(logger, "Update persist worker lock acquired", zap.String("server_id", serverID))
			// Blocks while this instance remains the active worker.
			runUpdatePersistWorker(ctx, msgBus, dbStore, logger, serverID)
		}()
		select {
		case <-ctx.Done():
			return
		default:
		}
		// If the worker exited (including panic), pause briefly before retry.
		time.Sleep(updatePersistTick)
	}
}
// runUpdatePersistWorker is the main loop executed while this instance holds
// the persist-worker lock. It processes stream updates on a fixed tick,
// refreshes the distributed lock, and emits a periodic debug heartbeat.
// It returns (releasing the lock best-effort) when ctx is cancelled or the
// lock can no longer be refreshed.
func runUpdatePersistWorker(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) {
	ticker := time.NewTicker(updatePersistTick)
	defer ticker.Stop()
	// Refresh at half the TTL so a single missed refresh does not lose the lock.
	refreshTicker := time.NewTicker(updatePersistLockTTL / 2)
	defer refreshTicker.Stop()
	heartbeatTicker := time.NewTicker(updateHeartbeatEvery)
	defer heartbeatTicker.Stop()
	for {
		select {
		case <-ctx.Done():
			_ = msgBus.ReleaseLock(ctx, updatePersistLockKey)
			return
		case <-refreshTicker.C:
			ok, err := msgBus.RefreshLock(ctx, updatePersistLockKey, updatePersistLockTTL)
			if err != nil || !ok {
				// Lock expired or Redis unreachable: stop so another instance
				// can take over; release is best-effort in case we still own it.
				logWorker(logger, "Update persist worker lock lost", zap.Error(err))
				_ = msgBus.ReleaseLock(ctx, updatePersistLockKey)
				return
			}
		case <-heartbeatTicker.C:
			if logger != nil {
				logger.Debug("Update persist worker heartbeat", zap.String("server_id", serverID))
			}
		case <-ticker.C:
			// Errors are logged and swallowed: the consumer group retains
			// unacked entries, so the next tick retries them.
			if err := processUpdatePersistence(ctx, msgBus, dbStore, logger, serverID); err != nil {
				logWorker(logger, "Update persist worker tick failed", zap.Error(err))
			}
		}
	}
}
// processUpdatePersistence runs one persistence pass. For every document with
// recent stream activity it: ensures the consumer group exists, claims idle
// pending entries left by crashed consumers, reads new entries, batches
// "update" payloads into Postgres, and acks only what was durably handled —
// a failed insert skips the ack so the entries are retried next tick.
func processUpdatePersistence(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) error {
	// Only process documents with recent stream activity (active in the last 60 seconds).
	// NOTE(review): assumes publishers write each doc's last-activity unix time
	// as the member score in "active-streams" — confirm against the publisher.
	cutoff := float64(time.Now().Add(-60 * time.Second).Unix())
	activeDocIDs, err := msgBus.ZRangeByScore(ctx, "active-streams", cutoff, float64(time.Now().Unix()))
	if err != nil {
		return fmt.Errorf("failed to get active streams: %w", err)
	}
	// Prune stale entries older than 5 minutes (best-effort cleanup)
	stale := float64(time.Now().Add(-5 * time.Minute).Unix())
	if _, err := msgBus.ZRemRangeByScore(ctx, "active-streams", 0, stale); err != nil {
		logWorker(logger, "Failed to prune stale active-streams entries", zap.Error(err))
	}
	for _, docIDStr := range activeDocIDs {
		docID, err := uuid.Parse(docIDStr)
		if err != nil {
			logWorker(logger, "Invalid document ID in active-streams", zap.String("doc_id", docIDStr))
			continue
		}
		streamKey := "stream:" + docIDStr
		if err := ensureConsumerGroup(ctx, msgBus, streamKey, updatePersistGroupName); err != nil {
			logWorker(logger, "Failed to ensure update persist consumer group", zap.String("stream", streamKey), zap.Error(err))
			continue
		}
		var ackIDs []string
		docEntries := make([]store.UpdateHistoryEntry, 0, updateBatchSize)
		// First, try to claim idle pending messages (e.g., from previous crashes)
		claimed, _, err := msgBus.XAutoClaim(ctx, streamKey, updatePersistGroupName, serverID, updateAutoClaimIdle, "0-0", updateReadCount)
		if err != nil {
			// Claim failure is non-fatal; fall through to reading new entries.
			logWorker(logger, "XAutoClaim failed", zap.String("stream", streamKey), zap.Error(err))
		} else if len(claimed) > 0 {
			collectStreamMessages(ctx, msgBus, dbStore, logger, docID, streamKey, claimed, &docEntries, &ackIDs)
		}
		// ">" reads entries never delivered to this group before.
		messages, err := msgBus.XReadGroup(ctx, updatePersistGroupName, serverID, []string{streamKey, ">"}, updateReadCount, updateReadBlock)
		if err != nil {
			logWorker(logger, "XReadGroup failed", zap.String("stream", streamKey), zap.Error(err))
			continue
		}
		if len(messages) > 0 {
			collectStreamMessages(ctx, msgBus, dbStore, logger, docID, streamKey, messages, &docEntries, &ackIDs)
		}
		if len(docEntries) > 0 {
			if err := dbStore.InsertUpdateHistoryBatch(ctx, docEntries); err != nil {
				logWorker(logger, "Failed to insert update history batch", zap.Error(err))
				// Skip ACK to retry on next tick
				continue
			}
		}
		// Ack only after the batch is durable (inserts are idempotent, so a
		// crash between insert and ack is safe).
		if len(ackIDs) > 0 {
			if _, err := msgBus.XAck(ctx, streamKey, updatePersistGroupName, ackIDs...); err != nil {
				logWorker(logger, "XAck failed", zap.String("stream", streamKey), zap.Error(err))
			}
		}
	}
	return nil
}
// collectStreamMessages converts raw stream messages into history entries and
// checkpoint/trim actions:
//
//   - "update": base64-decode the Yjs payload and append an entry to
//     docEntries for batch insertion by the caller.
//   - "snapshot": persist a stream checkpoint, prune DB history more than
//     updateSafeSeqLag sequence numbers behind it, and trim the Redis stream
//     up to the snapshot entry (all best-effort).
//   - any other type is ignored but still acked.
//
// Every processed message ID is appended to ackIDs EXCEPT updates whose
// payload fails to decode. NOTE(review): such a poison message is never
// acked, so it will be re-claimed and fail on every pass — consider acking
// (or dead-lettering) it after logging.
func collectStreamMessages(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, documentID uuid.UUID, streamKey string, messages []messagebus.StreamMessage, docEntries *[]store.UpdateHistoryEntry, ackIDs *[]string) {
	for _, msg := range messages {
		msgType := getString(msg.Values["type"])
		switch msgType {
		case "update":
			payloadB64 := getString(msg.Values["yjs_payload"])
			payload, err := base64.StdEncoding.DecodeString(payloadB64)
			if err != nil {
				logWorker(logger, "Failed to decode update payload",
					zap.String("stream", streamKey),
					zap.String("stream_id", msg.ID),
					zap.Error(err))
				// Skips the ack append below — see NOTE(review) in the doc comment.
				continue
			}
			seq := parseInt64(msg.Values["seq"])
			// Shadows the outer msgType: this is the entry's "msg_type" field,
			// not the "type" discriminator switched on above.
			msgType := normalizeMsgType(msg.Values["msg_type"])
			serverID := sanitizeText(getString(msg.Values["server_id"]))
			entry := store.UpdateHistoryEntry{
				DocumentID: documentID,
				StreamID:   msg.ID,
				Seq:        seq,
				Payload:    payload,
				MsgType:    msgType,
				ServerID:   serverID,
				CreatedAt:  time.Now().UTC(),
			}
			*docEntries = append(*docEntries, entry)
		case "snapshot":
			seq := parseInt64(msg.Values["seq"])
			if seq > 0 {
				if err := dbStore.UpsertStreamCheckpoint(ctx, documentID, msg.ID, seq); err != nil {
					logWorker(logger, "Failed to upsert stream checkpoint from snapshot marker",
						zap.String("document_id", documentID.String()),
						zap.Error(err))
				}
				// Retention: prune DB history based on checkpoint (best-effort)
				maxSeq := seq - updateSafeSeqLag
				if maxSeq > 0 {
					if err := dbStore.DeleteUpdateHistoryUpToSeq(ctx, documentID, maxSeq); err != nil {
						logWorker(logger, "Failed to prune update history",
							zap.String("document_id", documentID.String()),
							zap.Error(err))
					}
				}
				// Trim Redis stream to avoid unbounded growth (best-effort)
				if _, err := msgBus.XTrimMinID(ctx, streamKey, msg.ID); err != nil {
					logWorker(logger, "Failed to trim Redis stream",
						zap.String("stream", streamKey),
						zap.Error(err))
				}
			}
		}
		*ackIDs = append(*ackIDs, msg.ID)
	}
}
// ensureConsumerGroup creates the consumer group (and the stream itself if it
// does not exist yet); a group that already exists (BUSYGROUP) is not an error.
func ensureConsumerGroup(ctx context.Context, msgBus messagebus.MessageBus, streamKey, group string) error {
	err := msgBus.XGroupCreateMkStream(ctx, streamKey, group, "0-0")
	if err == nil || isBusyGroup(err) {
		return nil
	}
	return err
}
func isBusyGroup(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "BUSYGROUP")
}
// getString converts a Redis stream field value to a string. Stream values
// arrive as string or []byte depending on the client path; anything else is
// formatted with fmt.Sprint.
//
// A missing field (nil interface) now yields "" — previously fmt.Sprint
// rendered it as the literal text "<nil>", which would then be treated as a
// real value (e.g. persisted as server_id) by callers.
func getString(value interface{}) string {
	switch v := value.(type) {
	case nil:
		return ""
	case string:
		return v
	case []byte:
		return string(v)
	default:
		return fmt.Sprint(v)
	}
}
// parseInt64 best-effort converts a Redis stream field value to an int64,
// returning 0 for missing or unparseable values.
func parseInt64(value interface{}) int64 {
	switch v := value.(type) {
	case int64:
		return v
	case int:
		return int64(v)
	case uint64:
		return int64(v)
	case string:
		parsed, err := strconv.ParseInt(v, 10, 64)
		if err != nil {
			return 0
		}
		return parsed
	case []byte:
		parsed, err := strconv.ParseInt(string(v), 10, 64)
		if err != nil {
			return 0
		}
		return parsed
	default:
		return 0
	}
}
// sanitizeText replaces strings that cannot be stored in Postgres TEXT
// columns — those containing NUL bytes or invalid UTF-8 — with "".
func sanitizeText(s string) string {
	if strings.IndexByte(s, 0) >= 0 || !utf8.ValidString(s) {
		return ""
	}
	return s
}
// normalizeMsgType converts a stream "msg_type" field into a storable string.
// A single-byte value is treated as a raw binary type tag and rendered as its
// decimal byte value (e.g. byte 0x01 → "1"); longer values are sanitized for
// Postgres TEXT storage; anything else is formatted then sanitized.
func normalizeMsgType(value interface{}) string {
	var s string
	switch v := value.(type) {
	case string:
		s = v
	case []byte:
		s = string(v)
	default:
		return sanitizeText(fmt.Sprint(v))
	}
	switch len(s) {
	case 0:
		return ""
	case 1:
		return strconv.Itoa(int(s[0]))
	default:
		return sanitizeText(s)
	}
}
// logWorker logs a worker event at info level, tolerating a nil logger
// (logging is optional for workers).
func logWorker(logger *zap.Logger, msg string, fields ...zap.Field) {
	if logger != nil {
		logger.Info(msg, fields...)
	}
}

View File

@@ -0,0 +1,12 @@
-- Migration: Add stream checkpoints table for Redis Streams durability
-- This table tracks last processed stream position per document
-- One row per document (document_id is the primary key); rows disappear
-- automatically when the parent document is deleted (ON DELETE CASCADE).
CREATE TABLE IF NOT EXISTS stream_checkpoints (
    document_id UUID PRIMARY KEY REFERENCES documents(id) ON DELETE CASCADE,
    last_stream_id TEXT NOT NULL,  -- Redis Stream entry ID of the last durable update
    last_seq BIGINT NOT NULL DEFAULT 0,  -- sequence number of the last durable update
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Supports "most recently checkpointed documents first" scans.
CREATE INDEX IF NOT EXISTS idx_stream_checkpoints_updated_at
ON stream_checkpoints(updated_at DESC);

View File

@@ -0,0 +1,22 @@
-- Migration: Add update history table for Redis Stream WAL
-- This table stores per-update payloads for recovery and replay
CREATE TABLE IF NOT EXISTS document_update_history (
    id BIGSERIAL PRIMARY KEY,
    document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
    stream_id TEXT NOT NULL,  -- Redis Stream entry ID
    seq BIGINT NOT NULL,      -- application-level update sequence number
    payload BYTEA NOT NULL,   -- raw Yjs update bytes
    msg_type TEXT,
    server_id TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Idempotency keys: the persist worker relies on these unique indexes for its
-- ON CONFLICT DO NOTHING batch inserts.
CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_stream_id
ON document_update_history(document_id, stream_id);
CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_seq
ON document_update_history(document_id, seq);
-- Note: no separate non-unique index on (document_id, seq) is needed — the
-- unique index above already serves seq-ordered range scans; a duplicate
-- index would only slow down writes.

View File

@@ -0,0 +1,3 @@
-- Add 'guest' as a valid provider for guest mode login
-- CHECK constraints cannot be modified in place, so the constraint is dropped
-- and recreated. NOTE(review): between these two statements the column is
-- briefly unconstrained — confirm the migration runner wraps each file in a
-- transaction.
ALTER TABLE users DROP CONSTRAINT IF EXISTS users_provider_check;
ALTER TABLE users ADD CONSTRAINT users_provider_check CHECK (provider IN ('google', 'github', 'guest'));

271
backend/scripts/init.sql Normal file
View File

@@ -0,0 +1,271 @@
-- Migration: Create required PostgreSQL extensions
-- Extensions must be created before other migrations can use them
-- uuid-ossp: Provides functions for generating UUIDs (uuid_generate_v4())
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
-- pgcrypto: Provides cryptographic functions (used for token hashing)
CREATE EXTENSION IF NOT EXISTS "pgcrypto";
-- Initialize database schema for realtime collaboration
-- This is the base schema that creates core tables for documents and updates
CREATE TABLE IF NOT EXISTS documents (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
name VARCHAR(255) NOT NULL,
type VARCHAR(50) NOT NULL CHECK (type IN ('editor', 'kanban')),
yjs_state BYTEA,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE INDEX idx_documents_type ON documents(type);
CREATE INDEX idx_documents_created_at ON documents(created_at DESC);
-- Table for storing incremental updates (for history tracking)
CREATE TABLE IF NOT EXISTS document_updates (
id SERIAL PRIMARY KEY,
document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
update BYTEA NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW()
);
CREATE INDEX idx_updates_document_id ON document_updates(document_id);
CREATE INDEX idx_updates_created_at ON document_updates(created_at DESC);
-- Migration: Add users and sessions tables for authentication
-- Run this before 002_add_document_shares.sql
-- Enable UUID extension
CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
-- Users table
CREATE TABLE IF NOT EXISTS users (
    id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
    email VARCHAR(255) NOT NULL,
    name VARCHAR(255) NOT NULL,
    avatar_url TEXT,
    -- Includes 'guest' so guest-mode logins are valid on a fresh install;
    -- keeps this combined init script consistent with the standalone
    -- guest-provider migration that re-creates this constraint later in
    -- the file (that later ALTER remains a harmless no-op change).
    provider VARCHAR(50) NOT NULL CHECK (provider IN ('google', 'github', 'guest')),
    provider_user_id VARCHAR(255) NOT NULL,
    created_at TIMESTAMPTZ DEFAULT NOW(),
    updated_at TIMESTAMPTZ DEFAULT NOW(),
    last_login_at TIMESTAMPTZ,
    UNIQUE(provider, provider_user_id)
);
CREATE INDEX idx_users_email ON users(email);
CREATE INDEX idx_users_provider ON users(provider, provider_user_id);
COMMENT ON TABLE users IS 'Stores user accounts from OAuth providers';
COMMENT ON COLUMN users.provider IS 'Auth provider: google, github, or guest';
COMMENT ON COLUMN users.provider_user_id IS 'User ID from OAuth provider';
-- Sessions table
CREATE TABLE IF NOT EXISTS sessions (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
token_hash VARCHAR(64) NOT NULL,
expires_at TIMESTAMPTZ NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW(),
user_agent TEXT,
ip_address VARCHAR(45),
UNIQUE(token_hash)
);
CREATE INDEX idx_sessions_user_id ON sessions(user_id);
CREATE INDEX idx_sessions_token_hash ON sessions(token_hash);
CREATE INDEX idx_sessions_expires_at ON sessions(expires_at);
COMMENT ON TABLE sessions IS 'Stores active JWT sessions for revocation support';
COMMENT ON COLUMN sessions.token_hash IS 'SHA-256 hash of JWT token';
COMMENT ON COLUMN sessions.user_agent IS 'User agent string for device tracking';
-- Add owner_id to documents table if it doesn't exist
ALTER TABLE documents ADD COLUMN IF NOT EXISTS owner_id UUID REFERENCES users(id) ON DELETE SET NULL;
CREATE INDEX IF NOT EXISTS idx_documents_owner_id ON documents(owner_id);
COMMENT ON COLUMN documents.owner_id IS 'User who created the document';
-- Migration: Add document sharing with permissions
-- Run against existing database
CREATE TABLE IF NOT EXISTS document_shares (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
permission VARCHAR(20) NOT NULL CHECK (permission IN ('view', 'edit')),
created_at TIMESTAMPTZ DEFAULT NOW(),
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
UNIQUE(document_id, user_id)
);
CREATE INDEX idx_shares_document_id ON document_shares(document_id);
CREATE INDEX idx_shares_user_id ON document_shares(user_id);
CREATE INDEX idx_shares_permission ON document_shares(document_id, permission);
COMMENT ON TABLE document_shares IS 'Stores per-user document access permissions';
COMMENT ON COLUMN document_shares.permission IS 'Access level: view (read-only) or edit (read-write)';
-- Migration: Add public sharing support via share tokens
-- Dependencies: Run after 002_add_document_shares.sql
-- Purpose: Add share_token and is_public columns used by share link feature
-- Add columns for public sharing
ALTER TABLE documents ADD COLUMN IF NOT EXISTS share_token VARCHAR(255);
ALTER TABLE documents ADD COLUMN IF NOT EXISTS is_public BOOLEAN DEFAULT false NOT NULL;
-- Create indexes for performance
CREATE INDEX IF NOT EXISTS idx_documents_share_token ON documents(share_token) WHERE share_token IS NOT NULL;
CREATE INDEX IF NOT EXISTS idx_documents_is_public ON documents(is_public) WHERE is_public = true;
-- Constraint: public documents must have a token
-- This ensures data integrity - a document can't be public without a share token
ALTER TABLE documents ADD CONSTRAINT check_public_has_token
CHECK (is_public = false OR (is_public = true AND share_token IS NOT NULL));
-- Documentation
COMMENT ON COLUMN documents.share_token IS 'Public share token for link-based access (base64-encoded random string, 32 bytes)';
COMMENT ON COLUMN documents.is_public IS 'Whether document is publicly accessible via share link';
-- Migration: Add permission column for public share links
-- Dependencies: Run after 003_add_public_sharing.sql
-- Purpose: Store permission level (view/edit) for public share links
-- Add permission column to documents table
ALTER TABLE documents ADD COLUMN IF NOT EXISTS share_permission VARCHAR(20) DEFAULT 'edit' CHECK (share_permission IN ('view', 'edit'));
-- Create index for performance
CREATE INDEX IF NOT EXISTS idx_documents_share_permission ON documents(share_permission) WHERE is_public = true;
-- Documentation
COMMENT ON COLUMN documents.share_permission IS 'Permission level for public share link: view (read-only) or edit (read-write). Defaults to edit for backward compatibility.';
-- Migration: Add OAuth token storage
-- This table stores OAuth2 access tokens and refresh tokens from external providers
-- Used for refreshing user sessions without re-authentication
CREATE TABLE IF NOT EXISTS oauth_tokens (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
provider VARCHAR(50) NOT NULL,
access_token TEXT NOT NULL,
refresh_token TEXT,
token_type VARCHAR(50) DEFAULT 'Bearer',
expires_at TIMESTAMPTZ NOT NULL,
scope TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
CONSTRAINT oauth_tokens_user_id_provider_key UNIQUE (user_id, provider)
);
CREATE INDEX idx_oauth_tokens_user_id ON oauth_tokens(user_id);
-- Migration: Add document version history support
-- This migration creates the version history table, adds tracking columns,
-- and provides a helper function for version numbering
-- Create document versions table for storing version snapshots
CREATE TABLE IF NOT EXISTS document_versions (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
yjs_snapshot BYTEA NOT NULL,
text_preview TEXT,
version_number INTEGER NOT NULL,
created_by UUID REFERENCES users(id) ON DELETE SET NULL,
version_label TEXT,
is_auto_generated BOOLEAN DEFAULT true,
created_at TIMESTAMPTZ DEFAULT NOW(),
CONSTRAINT unique_document_version UNIQUE(document_id, version_number)
);
CREATE INDEX idx_document_versions_document_id ON document_versions(document_id, created_at DESC);
CREATE INDEX idx_document_versions_created_by ON document_versions(created_by);
-- Add version tracking columns to documents table
ALTER TABLE documents ADD COLUMN IF NOT EXISTS version_count INTEGER DEFAULT 0;
ALTER TABLE documents ADD COLUMN IF NOT EXISTS last_snapshot_at TIMESTAMPTZ;
-- Function to get the next version number for a document
-- This ensures version numbers are sequential and unique per document
CREATE OR REPLACE FUNCTION get_next_version_number(p_document_id UUID)
RETURNS INTEGER AS $$
DECLARE
next_version INTEGER;
BEGIN
SELECT COALESCE(MAX(version_number), 0) + 1
INTO next_version
FROM document_versions
WHERE document_id = p_document_id;
RETURN next_version;
END;
$$ LANGUAGE plpgsql;
-- Migration: Enable Row Level Security (RLS) on all tables
-- This enables RLS but uses permissive policies to allow all operations
-- Authorization is still handled by the Go backend middleware
-- Enable RLS on all tables
ALTER TABLE users ENABLE ROW LEVEL SECURITY;
ALTER TABLE sessions ENABLE ROW LEVEL SECURITY;
ALTER TABLE oauth_tokens ENABLE ROW LEVEL SECURITY;
ALTER TABLE documents ENABLE ROW LEVEL SECURITY;
ALTER TABLE document_updates ENABLE ROW LEVEL SECURITY;
ALTER TABLE document_shares ENABLE ROW LEVEL SECURITY;
ALTER TABLE document_versions ENABLE ROW LEVEL SECURITY;
-- Create permissive policies that allow all operations
-- This maintains current behavior where backend handles authorization
-- Users table
CREATE POLICY "Allow all operations on users" ON users FOR ALL USING (true);
-- Sessions table
CREATE POLICY "Allow all operations on sessions" ON sessions FOR ALL USING (true);
-- OAuth tokens table
CREATE POLICY "Allow all operations on oauth_tokens" ON oauth_tokens FOR ALL USING (true);
-- Documents table
CREATE POLICY "Allow all operations on documents" ON documents FOR ALL USING (true);
-- Document updates table
CREATE POLICY "Allow all operations on document_updates" ON document_updates FOR ALL USING (true);
-- Document shares table
CREATE POLICY "Allow all operations on document_shares" ON document_shares FOR ALL USING (true);
-- Document versions table
CREATE POLICY "Allow all operations on document_versions" ON document_versions FOR ALL USING (true);
-- Migration: Add stream checkpoints table for Redis Streams durability
-- This table tracks last processed stream position per document
CREATE TABLE IF NOT EXISTS stream_checkpoints (
document_id UUID PRIMARY KEY REFERENCES documents(id) ON DELETE CASCADE,
last_stream_id TEXT NOT NULL,
last_seq BIGINT NOT NULL DEFAULT 0,
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_stream_checkpoints_updated_at
ON stream_checkpoints(updated_at DESC);
-- Migration: Add update history table for Redis Stream WAL
-- This table stores per-update payloads for recovery and replay
CREATE TABLE IF NOT EXISTS document_update_history (
id BIGSERIAL PRIMARY KEY,
document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
stream_id TEXT NOT NULL,
seq BIGINT NOT NULL,
payload BYTEA NOT NULL,
msg_type TEXT,
server_id TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_stream_id
ON document_update_history(document_id, stream_id);
CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_seq
ON document_update_history(document_id, seq);
CREATE INDEX IF NOT EXISTS idx_update_history_document_seq
ON document_update_history(document_id, seq);
-- Add 'guest' as a valid provider for guest mode login
ALTER TABLE users DROP CONSTRAINT IF EXISTS users_provider_check;
ALTER TABLE users ADD CONSTRAINT users_provider_check CHECK (provider IN ('google', 'github', 'guest'));

View File

@@ -15,6 +15,7 @@ services:
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
- ./backend/scripts/init.sql:/docker-entrypoint-initdb.d/init.sql - ./backend/scripts/init.sql:/docker-entrypoint-initdb.d/init.sql
command: postgres -c shared_buffers=128MB -c max_connections=50
healthcheck: healthcheck:
test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"] test: ["CMD-SHELL", "pg_isready -U ${POSTGRES_USER} -d ${POSTGRES_DB}"]
interval: 10s interval: 10s
@@ -24,13 +25,33 @@ services:
redis: redis:
image: redis:7-alpine image: redis:7-alpine
container_name: realtime-collab-redis container_name: realtime-collab-redis
command: ["redis-server", "--appendonly", "yes", "--maxmemory", "64mb", "--maxmemory-policy", "allkeys-lru"]
ports: ports:
- "6379:6379" - "6379:6379"
volumes:
- redis_data:/data
healthcheck: healthcheck:
test: ["CMD", "redis-cli", "ping"] test: ["CMD", "redis-cli", "ping"]
interval: 10s interval: 10s
timeout: 3s timeout: 3s
retries: 5 retries: 5
backend:
build:
context: ./backend
dockerfile: Dockerfile
container_name: realtime-collab-backend
env_file:
- ./backend/.env
ports:
- "8080:8080"
depends_on:
postgres:
condition: service_healthy
redis:
condition: service_healthy
restart: unless-stopped
volumes: volumes:
postgres_data: postgres_data:
redis_data:

2
frontend/.gitignore vendored
View File

@@ -24,3 +24,5 @@ dist-ssr
*.sw? *.sw?
.env .env
.env.local .env.local

View File

@@ -2,7 +2,9 @@
<html lang="en"> <html lang="en">
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<link rel="icon" type="image/svg+xml" href="/vite.svg" /> <link rel="icon" type="image/png" sizes="32x32" href="/docnest-icon-32.png" />
<link rel="icon" type="image/png" sizes="64x64" href="/docnest-icon-64.png" />
<link rel="apple-touch-icon" sizes="180x180" href="/apple-touch-icon.png" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Realtime Collab</title> <title>Realtime Collab</title>
<!-- Google Fonts --> <!-- Google Fonts -->

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 744 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

View File

@@ -1,6 +1,15 @@
import type { User } from '../types/auth'; import type { User } from '../types/auth';
import { API_BASE_URL, authFetch } from './client'; import { API_BASE_URL, authFetch } from './client';
export async function guestLogin(): Promise<string> {
const res = await fetch(`${API_BASE_URL}/auth/guest`, { method: 'POST' });
if (!res.ok) {
throw new Error('Failed to create guest session');
}
const data = await res.json();
return data.token;
}
export const authApi = { export const authApi = {
getCurrentUser: async (): Promise<User> => { getCurrentUser: async (): Promise<User> => {
const response = await authFetch(`${API_BASE_URL}/auth/me`); const response = await authFetch(`${API_BASE_URL}/auth/me`);

View File

@@ -32,6 +32,10 @@ export async function authFetch(url: string, options?: RequestInit): Promise<Res
// Handle 401: Token expired or invalid // Handle 401: Token expired or invalid
if (response.status === 401) { if (response.status === 401) {
localStorage.removeItem('auth_token'); localStorage.removeItem('auth_token');
const currentPath = window.location.pathname + window.location.search;
if (currentPath !== '/' && currentPath !== '/login') {
sessionStorage.setItem('oauth_redirect', currentPath);
}
window.location.href = '/login'; window.location.href = '/login';
throw new Error('Unauthorized'); throw new Error('Unauthorized');
} }

View File

@@ -53,8 +53,11 @@ export const documentsApi = {
}, },
// Get document Yjs state // Get document Yjs state
getState: async (id: string): Promise<Uint8Array> => { getState: async (id: string, shareToken?: string): Promise<Uint8Array> => {
const response = await authFetch(`${API_BASE_URL}/documents/${id}/state`); const url = shareToken
? `${API_BASE_URL}/documents/${id}/state?share=${shareToken}`
: `${API_BASE_URL}/documents/${id}/state`;
const response = await authFetch(url);
if (!response.ok) throw new Error("Failed to fetch document state"); if (!response.ok) throw new Error("Failed to fetch document state");
const arrayBuffer = await response.arrayBuffer(); const arrayBuffer = await response.arrayBuffer();
return new Uint8Array(arrayBuffer); return new Uint8Array(arrayBuffer);

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.3 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.1 KiB

View File

@@ -7,7 +7,6 @@ import type { KanbanColumn, Task } from "./KanbanBoard.tsx";
interface ColumnProps { interface ColumnProps {
column: KanbanColumn; column: KanbanColumn;
onAddTask: (task: Task) => void; onAddTask: (task: Task) => void;
onMoveTask: (taskId: string, toColumnId: string) => void;
} }
const Column = ({ column, onAddTask }: ColumnProps) => { const Column = ({ column, onAddTask }: ColumnProps) => {
@@ -21,7 +20,7 @@ const Column = ({ column, onAddTask }: ColumnProps) => {
const handleAddTask = () => { const handleAddTask = () => {
if (newTaskTitle.trim()) { if (newTaskTitle.trim()) {
onAddTask({ onAddTask({
id: `task-${Date.now()}`, id: `task-${crypto.randomUUID()}`,
title: newTaskTitle, title: newTaskTitle,
description: "", description: "",
}); });

View File

@@ -6,6 +6,7 @@ import {
useSensor, useSensor,
useSensors, useSensors,
} from '@dnd-kit/core'; } from '@dnd-kit/core';
import { arrayMove } from '@dnd-kit/sortable';
import type { YjsProviders } from "../../lib/yjs"; import type { YjsProviders } from "../../lib/yjs";
import Column from "./Column.tsx"; import Column from "./Column.tsx";
@@ -71,17 +72,44 @@ const KanbanBoard = ({ providers }: KanbanBoardProps) => {
if (columnIndex !== -1) { if (columnIndex !== -1) {
providers.ydoc.transact(() => { providers.ydoc.transact(() => {
const column = cols[columnIndex] as KanbanColumn; const column = cols[columnIndex] as KanbanColumn;
column.tasks.push(task); const nextTasks = [...column.tasks, task];
const nextColumn = { ...column, tasks: nextTasks };
yarray.delete(columnIndex, 1); yarray.delete(columnIndex, 1);
yarray.insert(columnIndex, [column]); yarray.insert(columnIndex, [nextColumn]);
}); });
} }
}; };
const replaceColumn = (index: number, column: KanbanColumn) => {
const yarray = providers.ydoc.getArray("kanban-columns");
yarray.delete(index, 1);
yarray.insert(index, [column]);
};
const findColumnByTaskId = (taskId: string) =>
columns.find((col) => col.tasks.some((task) => task.id === taskId));
const reorderTask = (columnId: string, fromIndex: number, toIndex: number) => {
if (fromIndex === toIndex || fromIndex < 0 || toIndex < 0) return;
const yarray = providers.ydoc.getArray("kanban-columns");
const cols = yarray.toArray();
const columnIndex = cols.findIndex((col: any) => col.id === columnId);
if (columnIndex === -1) return;
const column = cols[columnIndex] as KanbanColumn;
const nextTasks = arrayMove(column.tasks, fromIndex, toIndex);
const nextColumn = { ...column, tasks: nextTasks };
providers.ydoc.transact(() => {
replaceColumn(columnIndex, nextColumn);
});
};
const moveTask = ( const moveTask = (
fromColumnId: string, fromColumnId: string,
toColumnId: string, toColumnId: string,
taskId: string taskId: string,
overTaskId?: string
) => { ) => {
const yarray = providers.ydoc.getArray("kanban-columns"); const yarray = providers.ydoc.getArray("kanban-columns");
const cols = yarray.toArray(); const cols = yarray.toArray();
@@ -91,18 +119,30 @@ const KanbanBoard = ({ providers }: KanbanBoardProps) => {
if (fromIndex !== -1 && toIndex !== -1) { if (fromIndex !== -1 && toIndex !== -1) {
providers.ydoc.transact(() => { providers.ydoc.transact(() => {
const fromCol = { ...(cols[fromIndex] as KanbanColumn) }; const fromCol = cols[fromIndex] as KanbanColumn;
const toCol = { ...(cols[toIndex] as KanbanColumn) }; const toCol = cols[toIndex] as KanbanColumn;
const nextFromTasks = [...fromCol.tasks];
const nextToTasks = fromIndex === toIndex ? nextFromTasks : [...toCol.tasks];
const taskIndex = fromCol.tasks.findIndex((t: Task) => t.id === taskId); const taskIndex = nextFromTasks.findIndex((t: Task) => t.id === taskId);
if (taskIndex !== -1) { if (taskIndex !== -1) {
const [task] = fromCol.tasks.splice(taskIndex, 1); const [task] = nextFromTasks.splice(taskIndex, 1);
toCol.tasks.push(task); const insertIndex =
overTaskId && overTaskId !== toColumnId
? nextToTasks.findIndex((t: Task) => t.id === overTaskId)
: -1;
yarray.delete(fromIndex, 1); if (insertIndex >= 0) {
yarray.insert(fromIndex, [fromCol]); nextToTasks.splice(insertIndex, 0, task);
yarray.delete(toIndex, 1); } else {
yarray.insert(toIndex, [toCol]); nextToTasks.push(task);
}
const nextFromCol = { ...fromCol, tasks: nextFromTasks };
const nextToCol = { ...toCol, tasks: nextToTasks };
replaceColumn(fromIndex, nextFromCol);
replaceColumn(toIndex, nextToCol);
} }
}); });
} }
@@ -114,16 +154,28 @@ const KanbanBoard = ({ providers }: KanbanBoardProps) => {
if (!over) return; if (!over) return;
const taskId = active.id as string; const taskId = active.id as string;
const targetColumnId = over.id as string; const overId = over.id as string;
// Find which column the task is currently in // Find which column the task is currently in
const fromColumn = columns.find(col => const fromColumn = findColumnByTaskId(taskId);
col.tasks.some(task => task.id === taskId) if (!fromColumn) return;
);
if (fromColumn && fromColumn.id !== targetColumnId) { const overColumn =
moveTask(fromColumn.id, targetColumnId, taskId); columns.find((col) => col.id === overId) || findColumnByTaskId(overId);
if (!overColumn) return;
if (fromColumn.id === overColumn.id) {
// Reorder within the same column
const oldIndex = fromColumn.tasks.findIndex((task) => task.id === taskId);
const newIndex = fromColumn.tasks.findIndex((task) => task.id === overId);
if (newIndex !== -1 && oldIndex !== -1) {
reorderTask(fromColumn.id, oldIndex, newIndex);
}
return;
} }
// Move to a different column
moveTask(fromColumn.id, overColumn.id, taskId, overId);
}; };
return ( return (
@@ -134,9 +186,6 @@ const KanbanBoard = ({ providers }: KanbanBoardProps) => {
key={column.id} key={column.id}
column={column} column={column}
onAddTask={(task) => addTask(column.id, task)} onAddTask={(task) => addTask(column.id, task)}
onMoveTask={(taskId, toColumnId) =>
moveTask(column.id, toColumnId, taskId)
}
/> />
))} ))}
</div> </div>

View File

@@ -1,7 +1,8 @@
import { useAuth } from '@/contexts/AuthContext'; import { useAuth } from '@/contexts/AuthContext';
import { Button } from '@/components/ui/button'; import { Button } from '@/components/ui/button';
import PixelIcon from '@/components/PixelIcon/PixelIcon';
import { LogOut } from 'lucide-react'; import { LogOut } from 'lucide-react';
import DocNestLogo from '@/assets/docnest/docnest-icon-128.png';
import ThemeToggle from '@/components/ThemeToggle';
function Navbar() { function Navbar() {
const { user, logout } = useAuth(); const { user, logout } = useAuth();
@@ -35,12 +36,18 @@ function Navbar() {
gap-2 gap-2
" "
> >
<PixelIcon name="gem" size={18} color="hsl(var(--brand-teal))" /> <img
src={DocNestLogo}
alt="DocNest"
className="w-6 h-6"
style={{ imageRendering: 'pixelated' }}
/>
DocNest DocNest
</a> </a>
{/* User Section */} {/* User Section */}
<div className="flex items-center gap-4"> <div className="flex items-center gap-4">
<ThemeToggle className="shadow-soft hover:shadow-card transition-all duration-150" />
{user.avatar_url && ( {user.avatar_url && (
<img <img
src={user.avatar_url} src={user.avatar_url}

View File

@@ -24,7 +24,7 @@ function ProtectedRoute({ children }: ProtectedRouteProps) {
} }
if (!user) { if (!user) {
return <Navigate to={`/login?redirect=${location.pathname}`} replace />; return <Navigate to={`/login?redirect=${encodeURIComponent(location.pathname + location.search)}`} replace />;
} }
return <>{children}</>; return <>{children}</>;

View File

@@ -6,12 +6,19 @@ import './ShareModal.css';
interface ShareModalProps { interface ShareModalProps {
documentId: string; documentId: string;
documentType?: 'editor' | 'kanban';
onClose: () => void; onClose: () => void;
currentPermission?: string; currentPermission?: string;
currentRole?: string; currentRole?: string;
} }
function ShareModal({ documentId, onClose, currentPermission, currentRole }: ShareModalProps) { function ShareModal({
documentId,
documentType = 'editor',
onClose,
currentPermission,
currentRole,
}: ShareModalProps) {
const [activeTab, setActiveTab] = useState<'users' | 'link'>('users'); const [activeTab, setActiveTab] = useState<'users' | 'link'>('users');
const [shares, setShares] = useState<DocumentShareWithUser[]>([]); const [shares, setShares] = useState<DocumentShareWithUser[]>([]);
const [shareLink, setShareLink] = useState<ShareLink | null>(null); const [shareLink, setShareLink] = useState<ShareLink | null>(null);
@@ -24,7 +31,7 @@ function ShareModal({ documentId, onClose, currentPermission, currentRole }: Sha
const [permission, setPermission] = useState<'view' | 'edit'>('view'); const [permission, setPermission] = useState<'view' | 'edit'>('view');
// Form state for link sharing // Form state for link sharing
const [linkPermission, setLinkPermission] = useState<'view' | 'edit'>('view'); const [linkPermission, setLinkPermission] = useState<'view' | 'edit'>('edit');
const [copied, setCopied] = useState(false); const [copied, setCopied] = useState(false);
// Load shares on mount // Load shares on mount
@@ -138,7 +145,7 @@ function ShareModal({ documentId, onClose, currentPermission, currentRole }: Sha
const handleCopyLink = () => { const handleCopyLink = () => {
if (!shareLink) return; if (!shareLink) return;
const url = `${window.location.origin}/editor/${documentId}?share=${shareLink.token}`; const url = `${window.location.origin}/${documentType}/${documentId}?share=${shareLink.token}`;
navigator.clipboard.writeText(url); navigator.clipboard.writeText(url);
setCopied(true); setCopied(true);
setTimeout(() => setCopied(false), 2000); setTimeout(() => setCopied(false), 2000);
@@ -278,7 +285,7 @@ function ShareModal({ documentId, onClose, currentPermission, currentRole }: Sha
<div className="link-box"> <div className="link-box">
<input <input
type="text" type="text"
value={`${window.location.origin}/editor/${documentId}?share=${shareLink.token}`} value={`${window.location.origin}/${documentType}/${documentId}?share=${shareLink.token}`}
readOnly readOnly
className="link-input" className="link-input"
/> />

View File

@@ -0,0 +1,35 @@
import { useEffect, useState } from "react";
import { Moon, Sun } from "lucide-react";
import { applyTheme, getPreferredTheme, type ThemeMode } from "@/lib/theme";
import { Button } from "@/components/ui/button";
type ThemeToggleProps = {
className?: string;
size?: "sm" | "default" | "icon";
};
function ThemeToggle({ className, size = "icon" }: ThemeToggleProps) {
const [theme, setTheme] = useState<ThemeMode>(() => getPreferredTheme());
useEffect(() => {
applyTheme(theme);
}, [theme]);
const nextTheme: ThemeMode = theme === "dark" ? "light" : "dark";
return (
<Button
type="button"
variant="outline"
size={size}
onClick={() => setTheme(nextTheme)}
className={className}
aria-label={`Switch to ${nextTheme} mode`}
title={`Switch to ${nextTheme} mode`}
>
{theme === "dark" ? <Sun className="h-4 w-4" /> : <Moon className="h-4 w-4" />}
</Button>
);
}
export default ThemeToggle;

View File

@@ -1,4 +1,4 @@
import { createContext, useContext, useState, useEffect } from 'react'; import { createContext, useContext, useState, useEffect, useCallback } from 'react';
import type { ReactNode } from 'react'; import type { ReactNode } from 'react';
import type { User, AuthContextType } from '../types/auth'; import type { User, AuthContextType } from '../types/auth';
import { authApi } from '../api/auth'; import { authApi } from '../api/auth';
@@ -40,7 +40,7 @@ export function AuthProvider({ children }: { children: ReactNode }) {
initAuth(); initAuth();
}, []); }, []);
const login = async (newToken: string) => { const login = useCallback(async (newToken: string) => {
try { try {
setLoading(true); setLoading(true);
setError(null); setError(null);
@@ -60,7 +60,7 @@ export function AuthProvider({ children }: { children: ReactNode }) {
} finally { } finally {
setLoading(false); setLoading(false);
} }
}; }, []);
const logout = () => { const logout = () => {
localStorage.removeItem('auth_token'); localStorage.removeItem('auth_token');

View File

@@ -157,10 +157,34 @@ export const useYjsDocument = (documentId: string, shareToken?: string) => {
setSynced(true); setSynced(true);
}); });
// Connection stability monitoring with reconnection limits
let reconnectCount = 0;
const maxReconnects = 10;
yjsProviders.websocketProvider.on( yjsProviders.websocketProvider.on(
"status", "status",
(event: { status: string }) => { (event: { status: string }) => {
console.log("WebSocket status:", event.status); console.log("WebSocket status:", event.status);
if (event.status === "disconnected") {
reconnectCount++;
if (reconnectCount >= maxReconnects) {
console.error(
"Max reconnection attempts reached. Please refresh the page."
);
// Could optionally show a user notification here
} else {
console.log(
`Reconnection attempt ${reconnectCount}/${maxReconnects}`
);
}
} else if (event.status === "connected") {
// Reset counter on successful connection
if (reconnectCount > 0) {
console.log("Reconnected successfully, resetting counter");
}
reconnectCount = 0;
}
} }
); );

View File

@@ -27,6 +27,28 @@
--ring: 214 89% 52%; --ring: 214 89% 52%;
--radius: 0.75rem; --radius: 0.75rem;
} }
.dark {
--background: 215 26% 7%;
--foreground: 0 0% 98%;
--card: 215 21% 11%;
--card-foreground: 0 0% 98%;
--popover: 215 21% 11%;
--popover-foreground: 0 0% 98%;
--primary: 213 93% 60%;
--primary-foreground: 0 0% 100%;
--secondary: 173 70% 42%;
--secondary-foreground: 0 0% 100%;
--muted: 215 15% 15%;
--muted-foreground: 215 10% 58%;
--accent: 197 100% 68%;
--accent-foreground: 215 26% 7%;
--destructive: 0 70% 52%;
--destructive-foreground: 0 0% 100%;
--border: 215 12% 21%;
--input: 215 12% 21%;
--ring: 213 93% 60%;
}
} }
* { * {
@@ -81,6 +103,47 @@
--pixel-text-muted: #64748B; --pixel-text-muted: #64748B;
} }
.dark {
--surface: 215 21% 11%;
--surface-muted: 215 15% 15%;
--text-primary: 0 0% 98%;
--text-secondary: 215 15% 82%;
--text-muted: 215 10% 70%;
--brand: 213 93% 60%;
--brand-dark: 213 90% 52%;
--brand-teal: 173 70% 42%;
--brand-teal-dark: 173 68% 34%;
--shadow-sm: 0 1px 2px rgba(0, 0, 0, 0.45);
--shadow-md: 0 12px 30px rgba(0, 0, 0, 0.55);
--shadow-lg: 0 20px 50px rgba(0, 0, 0, 0.65);
--focus-ring: 0 0 0 3px rgba(88, 166, 255, 0.35);
--gradient-hero: linear-gradient(120deg, #0d1117 0%, #111827 55%, #161b22 100%);
--gradient-accent: linear-gradient(120deg, #2f81f7 0%, #14b8a6 100%);
--pixel-purple-deep: #0b1f4b;
--pixel-purple-bright: #2f81f7;
--pixel-pink-vibrant: #58a6ff;
--pixel-cyan-bright: #14b8a6;
--pixel-orange-warm: #f59e0b;
--pixel-yellow-gold: #fbbf24;
--pixel-green-lime: #22c55e;
--pixel-green-forest: #16a34a;
--pixel-bg-dark: #0d1117;
--pixel-bg-medium: #161b22;
--pixel-bg-light: #1f2937;
--pixel-panel: #0f172a;
--pixel-white: #e5e7eb;
--pixel-shadow-dark: rgba(0, 0, 0, 0.5);
--pixel-outline: #30363d;
--pixel-text-primary: #e5e7eb;
--pixel-text-secondary: #c9d1d9;
--pixel-text-muted: #8b949e;
}
body { body {
font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen',
'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif; 'Ubuntu', 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif;

37
frontend/src/lib/theme.ts Normal file
View File

@@ -0,0 +1,37 @@
export type ThemeMode = "light" | "dark";
export const getStoredTheme = (): ThemeMode | null => {
try {
const value = localStorage.getItem("theme");
if (value === "light" || value === "dark") {
return value;
}
} catch {
// Ignore storage access errors
}
return null;
};
export const getPreferredTheme = (): ThemeMode => {
const stored = getStoredTheme();
if (stored) return stored;
if (typeof window !== "undefined" && window.matchMedia) {
return window.matchMedia("(prefers-color-scheme: dark)").matches
? "dark"
: "light";
}
return "light";
};
export const applyTheme = (theme: ThemeMode) => {
if (typeof document === "undefined") return;
document.documentElement.classList.toggle("dark", theme === "dark");
try {
localStorage.setItem("theme", theme);
} catch {
// Ignore storage access errors
}
};

View File

@@ -30,7 +30,7 @@ export const createYjsDocument = async (
// Load initial state from database BEFORE connecting providers // Load initial state from database BEFORE connecting providers
try { try {
const state = await documentsApi.getState(documentId); const state = await documentsApi.getState(documentId, shareToken);
if (state && state.length > 0) { if (state && state.length > 0) {
Y.applyUpdate(ydoc, state); Y.applyUpdate(ydoc, state);
console.log('✓ Loaded document state from database'); console.log('✓ Loaded document state from database');
@@ -51,7 +51,10 @@ export const createYjsDocument = async (
wsUrl, wsUrl,
documentId, documentId,
ydoc, ydoc,
{ params: wsParams } {
params: wsParams,
maxBackoffTime: 10000, // Max 10s between reconnect attempts
}
); );
// Awareness for cursors and presence // Awareness for cursors and presence

View File

@@ -1,8 +1,11 @@
import React from "react"; import React from "react";
import ReactDOM from "react-dom/client"; import ReactDOM from "react-dom/client";
import App from "./App.tsx"; import App from "./App.tsx";
import { applyTheme, getPreferredTheme } from "./lib/theme";
import "./index.css"; import "./index.css";
applyTheme(getPreferredTheme());
ReactDOM.createRoot(document.getElementById("root")!).render( ReactDOM.createRoot(document.getElementById("root")!).render(
<React.StrictMode> <React.StrictMode>
<App /> <App />

View File

@@ -1,4 +1,4 @@
import { useEffect, useState } from 'react'; import { useEffect, useRef, useState } from 'react';
import { useNavigate, useSearchParams } from 'react-router-dom'; import { useNavigate, useSearchParams } from 'react-router-dom';
import { useAuth } from '../contexts/AuthContext'; import { useAuth } from '../contexts/AuthContext';
@@ -7,11 +7,19 @@ function AuthCallback() {
const navigate = useNavigate(); const navigate = useNavigate();
const { login } = useAuth(); const { login } = useAuth();
const [error, setError] = useState<string | null>(null); const [error, setError] = useState<string | null>(null);
const processedRef = useRef(false);
useEffect(() => { useEffect(() => {
if (processedRef.current) return;
processedRef.current = true;
const handleCallback = async () => { const handleCallback = async () => {
const token = searchParams.get('token'); const token = searchParams.get('token');
const redirect = searchParams.get('redirect') || '/'; const redirect =
searchParams.get('redirect') ||
sessionStorage.getItem('oauth_redirect') ||
'/';
sessionStorage.removeItem('oauth_redirect');
if (!token) { if (!token) {
setError('No authentication token received'); setError('No authentication token received');

View File

@@ -1,3 +1,4 @@
import { Eye } from "lucide-react";
import { useState } from "react"; import { useState } from "react";
import { useNavigate, useParams, useSearchParams } from "react-router-dom"; import { useNavigate, useParams, useSearchParams } from "react-router-dom";
import Editor from "../components/Editor/Editor.tsx"; import Editor from "../components/Editor/Editor.tsx";
@@ -6,14 +7,13 @@ import UserList from "../components/Presence/UserList.tsx";
import ShareModal from "../components/Share/ShareModal.tsx"; import ShareModal from "../components/Share/ShareModal.tsx";
import VersionHistoryPanel from "../components/VersionHistory/VersionHistoryPanel.tsx"; import VersionHistoryPanel from "../components/VersionHistory/VersionHistoryPanel.tsx";
import { useYjsDocument } from "../hooks/useYjsDocument.ts"; import { useYjsDocument } from "../hooks/useYjsDocument.ts";
import { Eye } from "lucide-react";
const EditorPage = () => { const EditorPage = () => {
const { id } = useParams<{ id: string }>(); const { id } = useParams<{ id: string }>();
const [searchParams] = useSearchParams(); const [searchParams] = useSearchParams();
const navigate = useNavigate(); const navigate = useNavigate();
const shareToken = searchParams.get('share') || undefined; const shareToken = searchParams.get('share') || undefined;
const { providers, synced, permission, role } = useYjsDocument(id!, shareToken); const { providers, permission, role } = useYjsDocument(id!, shareToken);
const [showShareModal, setShowShareModal] = useState(false); const [showShareModal, setShowShareModal] = useState(false);
const [showVersionHistory, setShowVersionHistory] = useState(false); const [showVersionHistory, setShowVersionHistory] = useState(false);
@@ -58,6 +58,7 @@ const EditorPage = () => {
{showShareModal && ( {showShareModal && (
<ShareModal <ShareModal
documentId={id!} documentId={id!}
documentType="editor"
onClose={() => setShowShareModal(false)} onClose={() => setShowShareModal(false)}
currentPermission={permission || undefined} currentPermission={permission || undefined}
currentRole={role || undefined} currentRole={role || undefined}

View File

@@ -1,9 +1,9 @@
import { useState } from "react"; import { useState } from "react";
import { useNavigate, useParams, useSearchParams } from "react-router-dom"; import { useNavigate, useParams, useSearchParams } from "react-router-dom";
import KanbanBoard from "../components/Kanban/KanbanBoard.tsx"; import KanbanBoard from "../components/Kanban/KanbanBoard.tsx";
import Navbar from "../components/Navbar.tsx";
import UserList from "../components/Presence/UserList.tsx"; import UserList from "../components/Presence/UserList.tsx";
import ShareModal from "../components/Share/ShareModal.tsx"; import ShareModal from "../components/Share/ShareModal.tsx";
import Navbar from "../components/Navbar.tsx";
import { useYjsDocument } from "../hooks/useYjsDocument.ts"; import { useYjsDocument } from "../hooks/useYjsDocument.ts";
const KanbanPage = () => { const KanbanPage = () => {
@@ -11,7 +11,7 @@ const KanbanPage = () => {
const [searchParams] = useSearchParams(); const [searchParams] = useSearchParams();
const navigate = useNavigate(); const navigate = useNavigate();
const shareToken = searchParams.get('share') || undefined; const shareToken = searchParams.get('share') || undefined;
const { providers, synced } = useYjsDocument(id!, shareToken); const { providers } = useYjsDocument(id!, shareToken);
const [showShareModal, setShowShareModal] = useState(false); const [showShareModal, setShowShareModal] = useState(false);
if (!providers) { if (!providers) {
@@ -42,7 +42,11 @@ const KanbanPage = () => {
</div> </div>
{showShareModal && ( {showShareModal && (
<ShareModal documentId={id!} onClose={() => setShowShareModal(false)} /> <ShareModal
documentId={id!}
documentType="kanban"
onClose={() => setShowShareModal(false)}
/>
)} )}
</div> </div>
); );

View File

@@ -6,6 +6,13 @@
background: hsl(var(--background)); background: hsl(var(--background));
} }
.landing-theme-toggle {
position: fixed;
top: 24px;
right: 24px;
z-index: 20;
}
/* ======================================== /* ========================================
Hero Section Hero Section
======================================== */ ======================================== */
@@ -50,6 +57,12 @@
margin-bottom: 1.5rem; margin-bottom: 1.5rem;
} }
.hero-logo-icon {
width: 28px;
height: 28px;
image-rendering: pixelated;
}
.hero-brand { .hero-brand {
font-size: 0.95rem; font-size: 0.95rem;
font-weight: 700; font-weight: 700;
@@ -128,6 +141,19 @@
background: hsl(var(--surface)); background: hsl(var(--surface));
} }
.landing-login-button.guest {
background: transparent;
border: 1px dashed hsl(var(--border));
color: hsl(var(--text-secondary));
font-size: 0.9rem;
padding: 0.6rem 1.5rem;
}
.landing-login-button.guest:hover {
color: hsl(var(--text-primary));
border-style: solid;
}
.landing-login-button.large { .landing-login-button.large {
padding: 1rem 2rem; padding: 1rem 2rem;
font-size: 1.05rem; font-size: 1.05rem;

View File

@@ -1,9 +1,19 @@
import { useState } from 'react';
import { useNavigate } from 'react-router-dom';
import { useAuth } from '../contexts/AuthContext';
import { guestLogin } from '../api/auth';
import FloatingGem from '../components/PixelSprites/FloatingGem'; import FloatingGem from '../components/PixelSprites/FloatingGem';
import PixelIcon from '../components/PixelIcon/PixelIcon'; import PixelIcon from '../components/PixelIcon/PixelIcon';
import DocNestLogo from '../assets/docnest/docnest-icon-128.png';
import ThemeToggle from '../components/ThemeToggle';
import { API_BASE_URL } from '../config'; import { API_BASE_URL } from '../config';
import './LandingPage.css'; import './LandingPage.css';
function LandingPage() { function LandingPage() {
const { login } = useAuth();
const navigate = useNavigate();
const [guestLoading, setGuestLoading] = useState(false);
const handleGoogleLogin = () => { const handleGoogleLogin = () => {
window.location.href = `${API_BASE_URL}/auth/google`; window.location.href = `${API_BASE_URL}/auth/google`;
}; };
@@ -12,8 +22,24 @@ function LandingPage() {
window.location.href = `${API_BASE_URL}/auth/github`; window.location.href = `${API_BASE_URL}/auth/github`;
}; };
const handleGuestLogin = async () => {
try {
setGuestLoading(true);
const token = await guestLogin();
await login(token);
navigate('/');
} catch (err) {
console.error('Guest login failed:', err);
} finally {
setGuestLoading(false);
}
};
return ( return (
<div className="landing-page"> <div className="landing-page">
<div className="landing-theme-toggle">
<ThemeToggle />
</div>
{/* Hero Section */} {/* Hero Section */}
<section className="landing-hero"> <section className="landing-hero">
<div className="hero-gem hero-gem-one"> <div className="hero-gem hero-gem-one">
@@ -26,7 +52,11 @@ function LandingPage() {
<div className="hero-grid"> <div className="hero-grid">
<div className="hero-content"> <div className="hero-content">
<div className="hero-logo"> <div className="hero-logo">
<PixelIcon name="gem" size={28} color="hsl(var(--brand-teal))" /> <img
src={DocNestLogo}
alt="DocNest"
className="hero-logo-icon"
/>
<span className="hero-brand">DocNest</span> <span className="hero-brand">DocNest</span>
</div> </div>
@@ -52,6 +82,13 @@ function LandingPage() {
<span>Continue with GitHub</span> <span>Continue with GitHub</span>
</button> </button>
</div> </div>
<button
className="landing-login-button guest"
onClick={handleGuestLogin}
disabled={guestLoading}
>
{guestLoading ? 'Entering...' : 'Try as Guest'}
</button>
<p className="hero-note">No credit card required.</p> <p className="hero-note">No credit card required.</p>
</div> </div>
</div> </div>

View File

@@ -5,6 +5,7 @@
justify-content: center; justify-content: center;
background: var(--gradient-hero); background: var(--gradient-hero);
padding: 24px; padding: 24px;
position: relative;
} }
.login-container { .login-container {
@@ -18,11 +19,32 @@
text-align: center; text-align: center;
} }
.login-theme-toggle {
position: fixed;
top: 24px;
right: 24px;
z-index: 20;
}
.login-brand {
display: flex;
align-items: center;
justify-content: center;
gap: 12px;
margin-bottom: 8px;
}
.login-logo {
width: 32px;
height: 32px;
image-rendering: pixelated;
}
.login-title { .login-title {
font-size: 32px; font-size: 32px;
font-weight: 700; font-weight: 700;
color: hsl(var(--text-primary)); color: hsl(var(--text-primary));
margin: 0 0 8px 0; margin: 0;
} }
.login-subtitle { .login-subtitle {
@@ -83,3 +105,36 @@
transform: translateY(0); transform: translateY(0);
box-shadow: var(--shadow-sm); box-shadow: var(--shadow-sm);
} }
.login-divider {
display: flex;
align-items: center;
gap: 12px;
margin: 4px 0;
}
.login-divider::before,
.login-divider::after {
content: '';
flex: 1;
height: 1px;
background: hsl(var(--border));
}
.login-divider span {
font-size: 13px;
color: hsl(var(--text-secondary));
white-space: nowrap;
}
.guest-button {
background: transparent;
border: 1px dashed hsl(var(--border));
color: hsl(var(--text-secondary));
}
.guest-button:hover {
background: hsl(var(--surface-hover, var(--border) / 0.1));
color: hsl(var(--text-primary));
border-style: solid;
}

View File

@@ -1,12 +1,17 @@
import { useEffect } from 'react'; import { useEffect, useState } from 'react';
import { useNavigate } from 'react-router-dom'; import { useNavigate, useSearchParams } from 'react-router-dom';
import { useAuth } from '../contexts/AuthContext'; import { useAuth } from '../contexts/AuthContext';
import { guestLogin } from '../api/auth';
import { API_BASE_URL } from '../config'; import { API_BASE_URL } from '../config';
import DocNestLogo from '../assets/docnest/docnest-icon-128.png';
import ThemeToggle from '../components/ThemeToggle';
import './LoginPage.css'; import './LoginPage.css';
function LoginPage() { function LoginPage() {
const { user, loading } = useAuth(); const { user, loading, login } = useAuth();
const navigate = useNavigate(); const navigate = useNavigate();
const [searchParams] = useSearchParams();
const [guestLoading, setGuestLoading] = useState(false);
useEffect(() => { useEffect(() => {
if (!loading && user) { if (!loading && user) {
@@ -14,12 +19,34 @@ function LoginPage() {
} }
}, [user, loading, navigate]); }, [user, loading, navigate]);
const saveRedirectAndGo = (oauthUrl: string) => {
const redirect = searchParams.get('redirect');
if (redirect) {
sessionStorage.setItem('oauth_redirect', decodeURIComponent(redirect));
}
window.location.href = oauthUrl;
};
const handleGoogleLogin = () => { const handleGoogleLogin = () => {
window.location.href = `${API_BASE_URL}/auth/google`; saveRedirectAndGo(`${API_BASE_URL}/auth/google`);
}; };
const handleGitHubLogin = () => { const handleGitHubLogin = () => {
window.location.href = `${API_BASE_URL}/auth/github`; saveRedirectAndGo(`${API_BASE_URL}/auth/github`);
};
const handleGuestLogin = async () => {
try {
setGuestLoading(true);
const token = await guestLogin();
await login(token);
const redirect = searchParams.get('redirect');
navigate(redirect ? decodeURIComponent(redirect) : '/');
} catch (err) {
console.error('Guest login failed:', err);
} finally {
setGuestLoading(false);
}
}; };
if (loading) { if (loading) {
@@ -34,8 +61,14 @@ function LoginPage() {
return ( return (
<div className="login-page"> <div className="login-page">
<div className="login-theme-toggle">
<ThemeToggle />
</div>
<div className="login-container"> <div className="login-container">
<h1 className="login-title">DocNest</h1> <div className="login-brand">
<img src={DocNestLogo} alt="DocNest" className="login-logo" />
<h1 className="login-title">DocNest</h1>
</div>
<p className="login-subtitle">Collaborate in real time with your team</p> <p className="login-subtitle">Collaborate in real time with your team</p>
<div className="login-buttons"> <div className="login-buttons">
@@ -76,6 +109,18 @@ function LoginPage() {
</svg> </svg>
Continue with GitHub Continue with GitHub
</button> </button>
<div className="login-divider">
<span></span>
</div>
<button
className="login-button guest-button"
onClick={handleGuestLogin}
disabled={guestLoading}
>
{guestLoading ? 'Entering...' : 'Continue as Guest'}
</button>
</div> </div>
</div> </div>
</div> </div>

49
k3s/backend.yaml Normal file
View File

@@ -0,0 +1,49 @@
# Backend API/WebSocket server. The image is built locally and imported into
# k3s, hence imagePullPolicy: Never.
apiVersion: apps/v1
kind: Deployment
metadata:
  name: realtime-collab-backend
spec:
  replicas: 1
  selector:
    matchLabels:
      app: realtime-collab-backend
  template:
    metadata:
      labels:
        app: realtime-collab-backend
    spec:
      containers:
        - name: backend
          image: realtime-collab-backend:latest
          # Never pull: the image only exists in the node's local containerd store.
          imagePullPolicy: Never
          ports:
            - containerPort: 8080
          # All runtime config (DATABASE_URL, OAuth creds, ...) comes from the
          # shared secret; see k3s/secret.example.yaml.
          envFrom:
            - secretRef:
                name: realtime-collab-secret
          resources:
            requests:
              memory: "32Mi"
              cpu: "50m"
            limits:
              memory: "128Mi"
              cpu: "300m"
          readinessProbe:
            httpGet:
              path: /health
              port: 8080
            initialDelaySeconds: 10
            periodSeconds: 10
---
# Exposed via NodePort 30080; a reverse proxy in front of the node terminates TLS.
apiVersion: v1
kind: Service
metadata:
  name: realtime-collab-backend-svc
spec:
  type: NodePort
  selector:
    app: realtime-collab-backend
  ports:
    - port: 8080
      targetPort: 8080
      nodePort: 30080

178
k3s/configmap.yaml Normal file
View File

@@ -0,0 +1,178 @@
apiVersion: v1
kind: ConfigMap
metadata:
  name: postgres-init-sql
data:
  # Mounted into /docker-entrypoint-initdb.d. The postgres entrypoint runs it
  # exactly once, on an EMPTY data directory — so the script must succeed in a
  # single pass; IF NOT EXISTS guards also let it be re-applied by hand safely.
  init.sql: |
    CREATE EXTENSION IF NOT EXISTS "uuid-ossp";
    CREATE EXTENSION IF NOT EXISTS "pgcrypto";

    -- Core documents table; yjs_state holds the merged CRDT snapshot.
    CREATE TABLE IF NOT EXISTS documents (
        id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
        name VARCHAR(255) NOT NULL,
        type VARCHAR(50) NOT NULL CHECK (type IN ('editor', 'kanban')),
        yjs_state BYTEA,
        created_at TIMESTAMPTZ DEFAULT NOW(),
        updated_at TIMESTAMPTZ DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_documents_type ON documents(type);
    CREATE INDEX IF NOT EXISTS idx_documents_created_at ON documents(created_at DESC);

    -- Incremental Yjs updates, compacted into documents.yjs_state elsewhere.
    CREATE TABLE IF NOT EXISTS document_updates (
        id SERIAL PRIMARY KEY,
        document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
        -- "update" is a reserved keyword in PostgreSQL; it must be quoted or
        -- the CREATE TABLE fails with a syntax error.
        "update" BYTEA NOT NULL,
        created_at TIMESTAMPTZ DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_updates_document_id ON document_updates(document_id);
    CREATE INDEX IF NOT EXISTS idx_updates_created_at ON document_updates(created_at DESC);

    -- One row per (provider, provider_user_id); the same email may appear
    -- under multiple providers.
    CREATE TABLE IF NOT EXISTS users (
        id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
        email VARCHAR(255) NOT NULL,
        name VARCHAR(255) NOT NULL,
        avatar_url TEXT,
        provider VARCHAR(50) NOT NULL CHECK (provider IN ('google', 'github', 'guest')),
        provider_user_id VARCHAR(255) NOT NULL,
        created_at TIMESTAMPTZ DEFAULT NOW(),
        updated_at TIMESTAMPTZ DEFAULT NOW(),
        last_login_at TIMESTAMPTZ,
        UNIQUE(provider, provider_user_id)
    );
    CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
    CREATE INDEX IF NOT EXISTS idx_users_provider ON users(provider, provider_user_id);

    -- Server-side sessions; only a SHA-256 hash of the bearer token is stored.
    CREATE TABLE IF NOT EXISTS sessions (
        id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
        user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
        token_hash VARCHAR(64) NOT NULL,
        expires_at TIMESTAMPTZ NOT NULL,
        created_at TIMESTAMPTZ DEFAULT NOW(),
        user_agent TEXT,
        ip_address VARCHAR(45),
        UNIQUE(token_hash)
    );
    CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON sessions(user_id);
    CREATE INDEX IF NOT EXISTS idx_sessions_token_hash ON sessions(token_hash);
    CREATE INDEX IF NOT EXISTS idx_sessions_expires_at ON sessions(expires_at);

    -- Ownership / sharing (originally separate migrations, folded together here).
    ALTER TABLE documents ADD COLUMN IF NOT EXISTS owner_id UUID REFERENCES users(id) ON DELETE SET NULL;
    CREATE INDEX IF NOT EXISTS idx_documents_owner_id ON documents(owner_id);

    CREATE TABLE IF NOT EXISTS document_shares (
        id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
        document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
        user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
        permission VARCHAR(20) NOT NULL CHECK (permission IN ('view', 'edit')),
        created_at TIMESTAMPTZ DEFAULT NOW(),
        created_by UUID REFERENCES users(id) ON DELETE SET NULL,
        UNIQUE(document_id, user_id)
    );
    CREATE INDEX IF NOT EXISTS idx_shares_document_id ON document_shares(document_id);
    CREATE INDEX IF NOT EXISTS idx_shares_user_id ON document_shares(user_id);
    CREATE INDEX IF NOT EXISTS idx_shares_permission ON document_shares(document_id, permission);

    -- Public link sharing: a public document must carry a share token.
    ALTER TABLE documents ADD COLUMN IF NOT EXISTS share_token VARCHAR(255);
    ALTER TABLE documents ADD COLUMN IF NOT EXISTS is_public BOOLEAN DEFAULT false NOT NULL;
    CREATE INDEX IF NOT EXISTS idx_documents_share_token ON documents(share_token) WHERE share_token IS NOT NULL;
    CREATE INDEX IF NOT EXISTS idx_documents_is_public ON documents(is_public) WHERE is_public = true;
    -- NOTE: ADD CONSTRAINT has no IF NOT EXISTS form; this is safe only
    -- because initdb runs this script exactly once on a fresh cluster.
    ALTER TABLE documents ADD CONSTRAINT check_public_has_token
        CHECK (is_public = false OR (is_public = true AND share_token IS NOT NULL));
    ALTER TABLE documents ADD COLUMN IF NOT EXISTS share_permission VARCHAR(20) DEFAULT 'edit' CHECK (share_permission IN ('view', 'edit'));
    CREATE INDEX IF NOT EXISTS idx_documents_share_permission ON documents(share_permission) WHERE is_public = true;

    -- OAuth provider tokens (one row per user+provider).
    CREATE TABLE IF NOT EXISTS oauth_tokens (
        id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
        user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
        provider VARCHAR(50) NOT NULL,
        access_token TEXT NOT NULL,
        refresh_token TEXT,
        token_type VARCHAR(50) DEFAULT 'Bearer',
        expires_at TIMESTAMPTZ NOT NULL,
        scope TEXT,
        created_at TIMESTAMPTZ DEFAULT NOW(),
        updated_at TIMESTAMPTZ DEFAULT NOW(),
        CONSTRAINT oauth_tokens_user_id_provider_key UNIQUE (user_id, provider)
    );
    CREATE INDEX IF NOT EXISTS idx_oauth_tokens_user_id ON oauth_tokens(user_id);

    -- Named/auto version snapshots of a document.
    CREATE TABLE IF NOT EXISTS document_versions (
        id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
        document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
        yjs_snapshot BYTEA NOT NULL,
        text_preview TEXT,
        version_number INTEGER NOT NULL,
        created_by UUID REFERENCES users(id) ON DELETE SET NULL,
        version_label TEXT,
        is_auto_generated BOOLEAN DEFAULT true,
        created_at TIMESTAMPTZ DEFAULT NOW(),
        CONSTRAINT unique_document_version UNIQUE(document_id, version_number)
    );
    CREATE INDEX IF NOT EXISTS idx_document_versions_document_id ON document_versions(document_id, created_at DESC);
    CREATE INDEX IF NOT EXISTS idx_document_versions_created_by ON document_versions(created_by);
    ALTER TABLE documents ADD COLUMN IF NOT EXISTS version_count INTEGER DEFAULT 0;
    ALTER TABLE documents ADD COLUMN IF NOT EXISTS last_snapshot_at TIMESTAMPTZ;

    -- Next version number for a document; MAX+1, starting at 1.
    CREATE OR REPLACE FUNCTION get_next_version_number(p_document_id UUID)
    RETURNS INTEGER AS $$
    DECLARE
        next_version INTEGER;
    BEGIN
        SELECT COALESCE(MAX(version_number), 0) + 1
        INTO next_version
        FROM document_versions
        WHERE document_id = p_document_id;
        RETURN next_version;
    END;
    $$ LANGUAGE plpgsql;

    -- RLS enabled with permissive allow-all policies (carried over from the
    -- hosted setup); access control is enforced in the backend, not here.
    -- NOTE: CREATE POLICY has no IF NOT EXISTS form — first run only.
    ALTER TABLE users ENABLE ROW LEVEL SECURITY;
    ALTER TABLE sessions ENABLE ROW LEVEL SECURITY;
    ALTER TABLE oauth_tokens ENABLE ROW LEVEL SECURITY;
    ALTER TABLE documents ENABLE ROW LEVEL SECURITY;
    ALTER TABLE document_updates ENABLE ROW LEVEL SECURITY;
    ALTER TABLE document_shares ENABLE ROW LEVEL SECURITY;
    ALTER TABLE document_versions ENABLE ROW LEVEL SECURITY;
    CREATE POLICY "Allow all operations on users" ON users FOR ALL USING (true);
    CREATE POLICY "Allow all operations on sessions" ON sessions FOR ALL USING (true);
    CREATE POLICY "Allow all operations on oauth_tokens" ON oauth_tokens FOR ALL USING (true);
    CREATE POLICY "Allow all operations on documents" ON documents FOR ALL USING (true);
    CREATE POLICY "Allow all operations on document_updates" ON document_updates FOR ALL USING (true);
    CREATE POLICY "Allow all operations on document_shares" ON document_shares FOR ALL USING (true);
    CREATE POLICY "Allow all operations on document_versions" ON document_versions FOR ALL USING (true);

    -- Redis Stream consumer checkpoint: last stream entry persisted per doc.
    CREATE TABLE IF NOT EXISTS stream_checkpoints (
        document_id UUID PRIMARY KEY REFERENCES documents(id) ON DELETE CASCADE,
        last_stream_id TEXT NOT NULL,
        last_seq BIGINT NOT NULL DEFAULT 0,
        updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_stream_checkpoints_updated_at ON stream_checkpoints(updated_at DESC);

    -- Durable copy of stream update payloads for recovery/replay; unique on
    -- both (doc, stream_id) and (doc, seq) so re-delivery is idempotent.
    CREATE TABLE IF NOT EXISTS document_update_history (
        id BIGSERIAL PRIMARY KEY,
        document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
        stream_id TEXT NOT NULL,
        seq BIGINT NOT NULL,
        payload BYTEA NOT NULL,
        msg_type TEXT,
        server_id TEXT,
        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
    );
    CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_stream_id ON document_update_history(document_id, stream_id);
    CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_seq ON document_update_history(document_id, seq);
    CREATE INDEX IF NOT EXISTS idx_update_history_document_seq ON document_update_history(document_id, seq);

69
k3s/postgres.yaml Normal file
View File

@@ -0,0 +1,69 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: postgres-pvc
spec:
  accessModes:
    - ReadWriteOnce
  resources:
    requests:
      storage: 5Gi
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: postgres
spec:
  replicas: 1
  # Recreate prevents the default RollingUpdate from starting a second
  # postgres pod against the same ReadWriteOnce volume / data directory
  # during a rollout.
  strategy:
    type: Recreate
  selector:
    matchLabels:
      app: postgres
  template:
    metadata:
      labels:
        app: postgres
    spec:
      containers:
        - name: postgres
          image: postgres:16-alpine
          # Tuned for a small node; secret supplies POSTGRES_USER/PASSWORD/DB.
          args: ["-c", "shared_buffers=128MB", "-c", "max_connections=50"]
          ports:
            - containerPort: 5432
          envFrom:
            - secretRef:
                name: realtime-collab-secret
          resources:
            requests:
              memory: "64Mi"
              cpu: "100m"
            limits:
              memory: "256Mi"
              cpu: "500m"
          volumeMounts:
            - name: postgres-data
              mountPath: /var/lib/postgresql/data
            - name: init-sql
              # Entrypoint runs *.sql from here on first boot only.
              mountPath: /docker-entrypoint-initdb.d
          readinessProbe:
            exec:
              # Run through a shell: Kubernetes $(VAR) expansion does not see
              # envFrom-sourced variables, so "$(POSTGRES_USER)" would have been
              # passed to pg_isready literally. $POSTGRES_USER is expanded by sh
              # from the real container environment.
              command: ["sh", "-c", "pg_isready -U \"$POSTGRES_USER\" -d \"$POSTGRES_DB\""]
            initialDelaySeconds: 10
            periodSeconds: 10
      volumes:
        - name: postgres-data
          persistentVolumeClaim:
            claimName: postgres-pvc
        - name: init-sql
          configMap:
            name: postgres-init-sql
---
apiVersion: v1
kind: Service
metadata:
  name: postgres
spec:
  selector:
    app: postgres
  ports:
    - port: 5432
      targetPort: 5432

61
k3s/redis.yaml Normal file
View File

@@ -0,0 +1,61 @@
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: redis-pvc
spec:
  accessModes:
    - ReadWriteOnce
  resources:
    requests:
      storage: 1Gi
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: redis
spec:
  replicas: 1
  # Recreate: never run two redis pods against the same RWO AOF volume.
  strategy:
    type: Recreate
  selector:
    matchLabels:
      app: redis
  template:
    metadata:
      labels:
        app: redis
    spec:
      containers:
        - name: redis
          image: redis:7-alpine
          # AOF persistence with a 64mb cap.
          # NOTE(review): allkeys-lru may evict Redis Stream keys that the
          # update-persist worker has not checkpointed yet; if streams must
          # survive memory pressure, consider noeviction (or volatile-*) —
          # confirm against the worker's delivery guarantees.
          args: ["redis-server", "--appendonly", "yes", "--maxmemory", "64mb", "--maxmemory-policy", "allkeys-lru"]
          ports:
            - containerPort: 6379
          resources:
            requests:
              memory: "32Mi"
              cpu: "50m"
            limits:
              memory: "128Mi"
              cpu: "200m"
          volumeMounts:
            - name: redis-data
              mountPath: /data
          readinessProbe:
            exec:
              command: ["redis-cli", "ping"]
            initialDelaySeconds: 5
            periodSeconds: 10
      volumes:
        - name: redis-data
          persistentVolumeClaim:
            claimName: redis-pvc
---
apiVersion: v1
kind: Service
metadata:
  name: redis
spec:
  selector:
    app: redis
  ports:
    - port: 6379
      targetPort: 6379

25
k3s/secret.example.yaml Normal file
View File

@@ -0,0 +1,25 @@
# Template only — copy to secret.yaml (git-ignored), fill in real values,
# then `kubectl apply -f secret.yaml`. Never commit real credentials.
apiVersion: v1
kind: Secret
metadata:
  name: realtime-collab-secret
type: Opaque
stringData:
  # Postgres (consumed by the postgres container via envFrom)
  POSTGRES_USER: "replace"
  POSTGRES_PASSWORD: "replace"
  POSTGRES_DB: "replace"
  # Backend — DATABASE_URL must match the three values above;
  # hostnames "postgres" and "redis" are the in-cluster Service names.
  DATABASE_URL: "postgres://user:pass@postgres:5432/dbname?sslmode=disable"
  REDIS_URL: "redis://redis:6379"
  JWT_SECRET: "replace"
  PORT: "8080"
  ENVIRONMENT: "production"
  BACKEND_URL: "https://collab.m1ngdaxie.com"
  FRONTEND_URL: "https://collab.m1ngdaxie.com"
  ALLOWED_ORIGINS: "https://collab.m1ngdaxie.com"
  # OAuth redirect URLs must exactly match the provider console configuration.
  GOOGLE_CLIENT_ID: "replace"
  GOOGLE_CLIENT_SECRET: "replace"
  GOOGLE_REDIRECT_URL: "https://collab.m1ngdaxie.com/api/auth/google/callback"
  GITHUB_CLIENT_ID: "replace"
  GITHUB_CLIENT_SECRET: "replace"
  GITHUB_REDIRECT_URL: "https://collab.m1ngdaxie.com/api/auth/github/callback"