From 50822600ad7c9db2fd3fa2ec6d3b38c16e042468 Mon Sep 17 00:00:00 2001 From: M1ngdaXie <156019134+M1ngdaXie@users.noreply.github.com> Date: Sun, 8 Mar 2026 17:13:42 -0700 Subject: [PATCH] feat: implement Redis Streams support with stream checkpoints and update history - Added Redis Streams operations to the message bus interface and implementation. - Introduced StreamCheckpoint model to track last processed stream entry per document. - Implemented UpsertStreamCheckpoint and GetStreamCheckpoint methods in the Postgres store. - Created document_update_history table for storing update payloads for recovery and replay. - Developed update persist worker to handle Redis Stream updates and persist them to Postgres. - Enhanced Docker Compose configuration for Redis with persistence. - Updated frontend API to support fetching document state with optional share token. - Added connection stability monitoring in the Yjs document hook. --- .gitignore | 7 +- backend/internal/auth/middleware.go | 38 +-- backend/internal/handlers/document.go | 97 +++++- backend/internal/handlers/document_test.go | 4 +- backend/internal/handlers/share_test.go | 3 +- backend/internal/handlers/websocket.go | 120 ++++++- backend/internal/hub/hub.go | 160 ++++++++- backend/internal/logger/logger.go | 2 +- backend/internal/messagebus/interface.go | 137 ++++++++ backend/internal/messagebus/redis.go | 284 +++++++++++++++- backend/internal/models/stream_checkpoint.go | 15 + backend/internal/store/postgres.go | 9 + backend/internal/store/stream_checkpoint.go | 46 +++ backend/internal/store/testutil.go | 14 +- backend/internal/store/update_history.go | 115 +++++++ .../internal/workers/update_persist_worker.go | 320 ++++++++++++++++++ .../scripts/010_add_stream_checkpoints.sql | 12 + backend/scripts/011_add_update_history.sql | 22 ++ docker-compose.yml | 4 + frontend/src/api/document.ts | 9 +- frontend/src/hooks/useYjsDocument.ts | 24 ++ frontend/src/lib/yjs.ts | 7 +- 22 files changed, 1371 insertions(+), 78 deletions(-) create mode 100644 backend/internal/models/stream_checkpoint.go create mode 100644 backend/internal/store/stream_checkpoint.go create mode 100644 backend/internal/store/update_history.go create mode 100644 backend/internal/workers/update_persist_worker.go create mode 100644 backend/scripts/010_add_stream_checkpoints.sql create mode 100644 backend/scripts/011_add_update_history.sql diff --git a/.gitignore b/.gitignore index 8192af3..6066bf2 100644 --- a/.gitignore +++ b/.gitignore @@ -34,4 +34,9 @@ build/ # Docker volumes and data postgres_data/ -.claude/ \ No newline at end of file +.claude/ + +#test folder profiles +loadtest/pprof + +/docs \ No newline at end of file diff --git a/backend/internal/auth/middleware.go b/backend/internal/auth/middleware.go index 9789b04..f8d6166 100644 --- a/backend/internal/auth/middleware.go +++ b/backend/internal/auth/middleware.go @@ -9,6 +9,7 @@ import ( "github.com/gin-gonic/gin" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" + "go.uber.org/zap" ) type contextKey string @@ -20,41 +21,40 @@ const ContextUserIDKey = "user_id" type AuthMiddleware struct { store store.Store jwtSecret string + logger *zap.Logger } // NewAuthMiddleware creates a new auth middleware -func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware { +func NewAuthMiddleware(store store.Store, jwtSecret string, logger *zap.Logger) *AuthMiddleware { + if logger == nil { + logger = zap.NewNop() + } return &AuthMiddleware{ store: store, jwtSecret: jwtSecret, + logger: logger, } } // 
RequireAuth middleware requires valid authentication func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc { return func(c *gin.Context) { - fmt.Println("🔒 RequireAuth: Starting authentication check") - user, claims, err := m.getUserFromToken(c) - fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err) - if claims != nil { - fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email) - } - if err != nil || user == nil { - fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user) + if err != nil { + m.logger.Warn("auth failed", + zap.Error(err), + zap.String("method", c.Request.Method), + zap.String("path", c.FullPath()), + ) + } c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.Abort() return } // Note: Name and Email might be empty for old JWT tokens - if claims.Name == "" || claims.Email == "" { - fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n") - } - - fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user) c.Set(ContextUserIDKey, user) c.Set("user_email", claims.Email) c.Set("user_name", claims.Name) @@ -88,21 +88,17 @@ func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc { // 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error) func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) { authHeader := c.GetHeader("Authorization") - fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader) if authHeader == "" { - fmt.Println("⚠️ getUserFromToken: No Authorization header") return nil, nil, nil } parts := strings.Split(authHeader, " ") if len(parts) != 2 || parts[0] != "Bearer" { - fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0]) return nil, nil, nil } tokenString := parts[1] - fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))]) token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { // 必须要验证签名算法是 HMAC (HS256) @@ -113,7 +109,6 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai }) if err != nil { - fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err) return nil, nil, err } @@ -123,17 +118,14 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai // 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String() userID, err := uuid.Parse(claims.Subject) if err != nil { - fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err) return nil, nil, fmt.Errorf("invalid user ID in token") } // 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email) // 这一步完全没有查数据库,速度极快 - fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email) return &userID, claims, nil } - fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid") return nil, nil, fmt.Errorf("invalid token claims") } @@ -141,8 +133,6 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai func GetUserFromContext(c *gin.Context) *uuid.UUID { // 修正点:使用和存入时完全一样的 Key val, exists := c.Get(ContextUserIDKey) - fmt.Println("within getFromContext the id is ... 
") - fmt.Println(val) if !exists { return nil } diff --git a/backend/internal/handlers/document.go b/backend/internal/handlers/document.go index 3069885..28be434 100644 --- a/backend/internal/handlers/document.go +++ b/backend/internal/handlers/document.go @@ -1,22 +1,33 @@ package handlers import ( - "fmt" + "context" "net/http" + "time" "github.com/M1ngdaXie/realtime-collab/internal/auth" + "github.com/M1ngdaXie/realtime-collab/internal/messagebus" "github.com/M1ngdaXie/realtime-collab/internal/models" "github.com/M1ngdaXie/realtime-collab/internal/store" "github.com/gin-gonic/gin" "github.com/google/uuid" + "go.uber.org/zap" ) type DocumentHandler struct { - store *store.PostgresStore + store *store.PostgresStore + messageBus messagebus.MessageBus + serverID string + logger *zap.Logger } -func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler { - return &DocumentHandler{store: s} +func NewDocumentHandler(s *store.PostgresStore, msgBus messagebus.MessageBus, serverID string, logger *zap.Logger) *DocumentHandler { + return &DocumentHandler{ + store: s, + messageBus: msgBus, + serverID: serverID, + logger: logger, + } } // CreateDocument creates a new document (requires auth) @@ -45,8 +56,6 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) { func (h *DocumentHandler) ListDocuments(c *gin.Context) { userID := auth.GetUserFromContext(c) - fmt.Println("Getting userId, which is : ") - fmt.Println(userID) if userID == nil { respondUnauthorized(c, "Authentication required to list documents") return @@ -113,6 +122,13 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) { } userID := auth.GetUserFromContext(c) + shareToken := c.Query("share") + + doc, err := h.store.GetDocument(id) + if err != nil { + respondNotFound(c, "document") + return + } // Check permission if authenticated if userID != nil { @@ -125,12 +141,22 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) { respondForbidden(c, "Access denied") return } - } - - doc, err := h.store.GetDocument(id) - if err != nil { - respondNotFound(c, "document") - return + } else { + // Unauthenticated: require valid share token or public doc + if shareToken != "" { + valid, err := h.store.ValidateShareToken(c.Request.Context(), id, shareToken) + if err != nil { + respondInternalError(c, "Failed to validate share token", err) + return + } + if !valid { + respondForbidden(c, "Invalid or expired share token") + return + } + } else if !doc.Is_Public { + respondForbidden(c, "This document is not public. 
Please sign in to access.") + return + } } // Return empty byte slice if state is nil (new document) @@ -191,6 +217,16 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) { return } + if streamID, seq, ok := h.addSnapshotMarker(c.Request.Context(), id); ok { + if err := h.store.UpsertStreamCheckpoint(c.Request.Context(), id, streamID, seq); err != nil { + if h.logger != nil { + h.logger.Warn("Failed to upsert stream checkpoint after snapshot", + zap.String("document_id", id.String()), + zap.Error(err)) + } + } + } + c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"}) } @@ -234,6 +270,43 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"}) } +func (h *DocumentHandler) addSnapshotMarker(ctx context.Context, documentID uuid.UUID) (string, int64, bool) { + if h.messageBus == nil { + return "", 0, false + } + + streamKey := "stream:" + documentID.String() + seqKey := "seq:" + documentID.String() + + seq, err := h.messageBus.Incr(ctx, seqKey) + if err != nil { + if h.logger != nil { + h.logger.Warn("Failed to increment snapshot sequence", + zap.String("document_id", documentID.String()), + zap.Error(err)) + } + return "", 0, false + } + + values := map[string]interface{}{ + "type": "snapshot", + "server_id": h.serverID, + "seq": seq, + "timestamp": time.Now().Format(time.RFC3339), + } + + streamID, err := h.messageBus.XAdd(ctx, streamKey, 10000, true, values) + if err != nil { + if h.logger != nil { + h.logger.Warn("Failed to add snapshot marker to stream", + zap.String("stream_key", streamKey), + zap.Error(err)) + } + return "", 0, false + } + return streamID, seq, true +} + // GetDocumentPermission returns the user's permission level for a document func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) { documentID, err := uuid.Parse(c.Param("id")) diff --git a/backend/internal/handlers/document_test.go b/backend/internal/handlers/document_test.go index 229c2f9..2796912 100644 --- a/backend/internal/handlers/document_test.go +++ b/backend/internal/handlers/document_test.go @@ -7,10 +7,12 @@ import ( "testing" "github.com/M1ngdaXie/realtime-collab/internal/auth" + "github.com/M1ngdaXie/realtime-collab/internal/messagebus" "github.com/M1ngdaXie/realtime-collab/internal/models" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/stretchr/testify/suite" + "go.uber.org/zap" ) // DocumentHandlerSuite tests document CRUD operations @@ -23,7 +25,7 @@ type DocumentHandlerSuite struct { // SetupTest runs before each test func (s *DocumentHandlerSuite) SetupTest() { s.BaseHandlerSuite.SetupTest() - s.handler = NewDocumentHandler(s.store) + s.handler = NewDocumentHandler(s.store, messagebus.NewLocalMessageBus(), "test-server", zap.NewNop()) s.setupRouter() } diff --git a/backend/internal/handlers/share_test.go b/backend/internal/handlers/share_test.go index 6607825..b3350ff 100644 --- a/backend/internal/handlers/share_test.go +++ b/backend/internal/handlers/share_test.go @@ -10,6 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/stretchr/testify/suite" + "go.uber.org/zap" ) // ShareHandlerSuite tests for share handler endpoints @@ -24,7 +25,7 @@ func (s *ShareHandlerSuite) SetupTest() { s.BaseHandlerSuite.SetupTest() // Create handler and router - authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret) + authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret, zap.NewNop()) s.handler = NewShareHandler(s.store, s.cfg) s.router = 
gin.New() diff --git a/backend/internal/handlers/websocket.go b/backend/internal/handlers/websocket.go index 57f5f5c..391d85b 100644 --- a/backend/internal/handlers/websocket.go +++ b/backend/internal/handlers/websocket.go @@ -1,13 +1,18 @@ package handlers import ( + "context" + "encoding/base64" + "fmt" "log" "net/http" + "strconv" "time" "github.com/M1ngdaXie/realtime-collab/internal/auth" "github.com/M1ngdaXie/realtime-collab/internal/config" "github.com/M1ngdaXie/realtime-collab/internal/hub" + "github.com/M1ngdaXie/realtime-collab/internal/messagebus" "github.com/M1ngdaXie/realtime-collab/internal/store" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -19,16 +24,18 @@ import ( var connectionSem = make(chan struct{}, 200) type WebSocketHandler struct { - hub *hub.Hub - store store.Store - cfg *config.Config + hub *hub.Hub + store store.Store + cfg *config.Config + msgBus messagebus.MessageBus } -func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config) *WebSocketHandler { +func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config, msgBus messagebus.MessageBus) *WebSocketHandler { return &WebSocketHandler{ - hub: h, - store: s, - cfg: cfg, + hub: h, + store: s, + cfg: cfg, + msgBus: msgBus, } } @@ -170,6 +177,105 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) { // Start goroutines go client.WritePump() go client.ReadPump() + go wsh.replayBacklog(client, documentID) log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID) } + +const maxReplayUpdates = 5000 + +func (wsh *WebSocketHandler) replayBacklog(client *hub.Client, documentID uuid.UUID) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + checkpoint, err := wsh.store.GetStreamCheckpoint(ctx, documentID) + if err != nil || checkpoint == nil || checkpoint.LastStreamID == "" { + return + } + + streamKey := "stream:" + documentID.String() + var sent int + + // Primary: Redis stream replay + if wsh.msgBus != nil { + messages, err := wsh.msgBus.XRange(ctx, streamKey, checkpoint.LastStreamID, "+") + if err == nil && len(messages) > 0 { + for _, msg := range messages { + if msg.ID == checkpoint.LastStreamID { + continue + } + if sent >= maxReplayUpdates { + log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String()) + return + } + msgType := getString(msg.Values["type"]) + if msgType != "update" { + continue + } + seq := parseInt64(msg.Values["seq"]) + if seq <= checkpoint.LastSeq { + continue + } + payloadB64 := getString(msg.Values["yjs_payload"]) + payload, err := base64.StdEncoding.DecodeString(payloadB64) + if err != nil { + continue + } + if client.Enqueue(payload) { + sent++ + } else { + return + } + } + return + } + } + + // Fallback: DB history replay + updates, err := wsh.store.ListUpdateHistoryAfterSeq(ctx, documentID, checkpoint.LastSeq, maxReplayUpdates) + if err != nil { + return + } + for _, upd := range updates { + if sent >= maxReplayUpdates { + log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String()) + return + } + if client.Enqueue(upd.Payload) { + sent++ + } else { + return + } + } +} + +func getString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case []byte: + return string(v) + default: + return fmt.Sprint(v) + } +} + +func parseInt64(value interface{}) int64 { + switch v := value.(type) { + case int64: + return v + case int: + return int64(v) + case uint64: + return int64(v) + case 
string: + if parsed, err := strconv.ParseInt(v, 10, 64); err == nil { + return parsed + } + case []byte: + if parsed, err := strconv.ParseInt(string(v), 10, 64); err == nil { + return parsed + } + } + return 0 +} diff --git a/backend/internal/hub/hub.go b/backend/internal/hub/hub.go index bb8868e..7cee452 100644 --- a/backend/internal/hub/hub.go +++ b/backend/internal/hub/hub.go @@ -2,6 +2,8 @@ package hub import ( "context" + "encoding/base64" + "strconv" "sync" "time" @@ -37,10 +39,11 @@ type Client struct { idsMu sync.Mutex } type Room struct { - ID string - clients map[*Client]bool - mu sync.RWMutex - cancel context.CancelFunc + ID string + clients map[*Client]bool + mu sync.RWMutex + cancel context.CancelFunc + reconnectCount int // Track Redis reconnection attempts for debugging } type Hub struct { @@ -64,6 +67,10 @@ type Hub struct { // Bounded worker pool for Redis SetAwareness awarenessQueue chan awarenessItem + + // Stream persistence worker pool (P1: Redis Streams durability) + streamQueue chan *Message // buffered queue for XADD operations + streamDone chan struct{} // close to signal stream workers to exit } const ( @@ -79,6 +86,13 @@ const ( // awarenessQueueSize is the buffer size for awareness updates. awarenessQueueSize = 4096 + + // streamWorkerCount is the number of fixed goroutines consuming from streamQueue. + // 50 workers match publish workers for consistent throughput. + streamWorkerCount = 50 + + // streamQueueSize is the buffer size for the stream persistence queue. + streamQueueSize = 4096 ) type awarenessItem struct { @@ -103,11 +117,15 @@ func NewHub(messagebus messagebus.MessageBus, serverID string, logger *zap.Logge publishDone: make(chan struct{}), // bounded awareness worker pool awarenessQueue: make(chan awarenessItem, awarenessQueueSize), + // Stream persistence worker pool + streamQueue: make(chan *Message, streamQueueSize), + streamDone: make(chan struct{}), } // Start the fixed worker pool for Redis publishing h.startPublishWorkers(publishWorkerCount) h.startAwarenessWorkers(awarenessWorkerCount) + h.startStreamWorkers(streamWorkerCount) return h } @@ -173,6 +191,82 @@ func (h *Hub) startAwarenessWorkers(n int) { h.logger.Info("Awareness worker pool started", zap.Int("workers", n)) } +// startStreamWorkers launches n goroutines that consume from streamQueue +// and add messages to Redis Streams for durability and replay. 
+func (h *Hub) startStreamWorkers(n int) { + for i := 0; i < n; i++ { + go func(workerID int) { + for { + select { + case <-h.streamDone: + h.logger.Info("Stream worker exiting", zap.Int("worker_id", workerID)) + return + case msg, ok := <-h.streamQueue: + if !ok { + return + } + h.addToStream(msg) + } + } + }(i) + } + h.logger.Info("Stream worker pool started", zap.Int("workers", n)) +} + +// encodeBase64 encodes binary data to base64 string for Redis storage +func encodeBase64(data []byte) string { + return base64.StdEncoding.EncodeToString(data) +} + +// addToStream adds a message to Redis Streams for durability +func (h *Hub) addToStream(msg *Message) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + streamKey := "stream:" + msg.RoomID + + // Get next sequence number atomically + seqKey := "seq:" + msg.RoomID + seq, err := h.messagebus.Incr(ctx, seqKey) + if err != nil { + h.logger.Error("Failed to increment sequence", + zap.String("room_id", msg.RoomID), + zap.Error(err)) + return + } + + // Encode payload as base64 (binary-safe storage) + payload := encodeBase64(msg.Data) + + // Extract Yjs message type from first byte as numeric string + msgType := "0" + if len(msg.Data) > 0 { + msgType = strconv.Itoa(int(msg.Data[0])) + } + + // Add entry to Stream with MAXLEN trimming + values := map[string]interface{}{ + "type": "update", + "server_id": h.serverID, + "yjs_payload": payload, + "msg_type": msgType, + "seq": seq, + "timestamp": time.Now().Format(time.RFC3339), + } + + _, err = h.messagebus.XAdd(ctx, streamKey, 10000, true, values) + if err != nil { + h.logger.Error("Failed to add to Stream", + zap.String("stream_key", streamKey), + zap.Int64("seq", seq), + zap.Error(err)) + return + } + + // Mark this document as active so the persist worker only processes active streams + _ = h.messagebus.ZAdd(ctx, "active-streams", float64(time.Now().Unix()), msg.RoomID) +} + func (h *Hub) Run() { for { select { @@ -471,6 +565,7 @@ func (h *Hub) broadcastMessage(message *Message) { // 只有本地客户端发出的消息 (sender != nil) 才推送到 Redis // P0 fix: send to bounded worker pool instead of spawning unbounded goroutines if message.sender != nil && !h.fallbackMode && h.messagebus != nil { + // 3a. Publish to Pub/Sub (real-time cross-server broadcast) select { case h.publishQueue <- message: // Successfully queued for async publish by worker pool @@ -479,6 +574,19 @@ func (h *Hub) broadcastMessage(message *Message) { h.logger.Warn("Publish queue full, dropping Redis publish", zap.String("room_id", message.RoomID)) } + + // 3b. 
Add to Stream for durability (only Type 0 updates, not Type 1 awareness) + // Type 0 = Yjs sync/update messages (document changes) + // Type 1 = Yjs awareness messages (cursors, presence) - ephemeral, skip + if len(message.Data) > 0 && message.Data[0] == 0 { + select { + case h.streamQueue <- message: + // Successfully queued for async Stream add + default: + h.logger.Warn("Stream queue full, dropping durability", + zap.String("room_id", message.RoomID)) + } + } } } @@ -504,10 +612,28 @@ func (h *Hub) broadcastToLocalClients(room *Room, data []byte, sender *Client) { } } func (h *Hub) startRoomMessageForwarding(ctx context.Context, roomID string, msgChan <-chan []byte) { - h.logger.Info("Starting message forwarding from Redis to room", - zap.String("room_id", roomID), - zap.String("server_id", h.serverID), - ) + // Increment and log reconnection count for debugging + h.mu.RLock() + room, exists := h.rooms[roomID] + h.mu.RUnlock() + + if exists { + room.mu.Lock() + room.reconnectCount++ + reconnectCount := room.reconnectCount + room.mu.Unlock() + + h.logger.Info("Starting message forwarding from Redis to room", + zap.String("room_id", roomID), + zap.String("server_id", h.serverID), + zap.Int("reconnect_count", reconnectCount), + ) + } else { + h.logger.Info("Starting message forwarding from Redis to room", + zap.String("room_id", roomID), + zap.String("server_id", h.serverID), + ) + } for { select { @@ -791,12 +917,28 @@ func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string UserAvatar: userAvatar, Permission: permission, Conn: conn, - send: make(chan []byte, 1024), + send: make(chan []byte, 8192), hub: hub, roomID: roomID, observedYjsIDs: make(map[uint64]uint64), } } + +// Enqueue sends a message to the client send buffer (non-blocking). +// Returns false if the buffer is full. 
+func (c *Client) Enqueue(message []byte) bool { + select { + case c.send <- message: + return true + default: + if c.hub != nil && c.hub.logger != nil { + c.hub.logger.Warn("Client send buffer full during replay", + zap.String("client_id", c.ID), + zap.String("room_id", c.roomID)) + } + return false + } +} func (c *Client) unregister() { c.unregisterOnce.Do(func() { c.hub.Unregister <- c diff --git a/backend/internal/logger/logger.go b/backend/internal/logger/logger.go index 9de2fdd..a06caab 100644 --- a/backend/internal/logger/logger.go +++ b/backend/internal/logger/logger.go @@ -23,7 +23,7 @@ func NewLogger(isDevelopment bool) (*zap.Logger, error) { // 👇 关键修改:直接拉到 Fatal 级别 // 这样 Error, Warn, Info, Debug 全部都会被忽略 // 彻底消除 IO 锁竞争 - config.Level = zap.NewAtomicLevelAt(zapcore.FatalLevel) + config.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel) logger, err := config.Build() if err != nil { diff --git a/backend/internal/messagebus/interface.go b/backend/internal/messagebus/interface.go index 8b4d025..8016db0 100644 --- a/backend/internal/messagebus/interface.go +++ b/backend/internal/messagebus/interface.go @@ -2,6 +2,7 @@ package messagebus import ( "context" + "time" ) // MessageBus abstracts message distribution across server instances @@ -33,6 +34,72 @@ type MessageBus interface { // Close gracefully shuts down the message bus Close() error + + // ========== Redis Streams Operations ========== + + // XAdd adds a new entry to a stream with optional MAXLEN trimming + XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) + + // XReadGroup reads messages from a stream using a consumer group + XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) + + // XAck acknowledges one or more messages from a consumer group + XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) + + // XGroupCreate creates a new consumer group for a stream + XGroupCreate(ctx context.Context, stream, group, start string) error + + // XGroupCreateMkStream creates a consumer group and the stream if it doesn't exist + XGroupCreateMkStream(ctx context.Context, stream, group, start string) error + + // XPending returns pending messages information for a consumer group + XPending(ctx context.Context, stream, group string) (*PendingInfo, error) + + // XClaim claims pending messages from a consumer group + XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) + + // XAutoClaim claims pending messages automatically (Redis >= 6.2) + // Returns claimed messages and next start ID. 
+ XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) + + // XRange reads a range of messages from a stream + XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) + + // XTrimMinID trims a stream to a minimum ID (time-based retention) + XTrimMinID(ctx context.Context, stream, minID string) (int64, error) + + // Incr increments a counter atomically (for sequence numbers) + Incr(ctx context.Context, key string) (int64, error) + + // ========== Sorted Set (ZSET) Operations ========== + + // ZAdd adds a member with a score to a sorted set (used for active-stream tracking) + ZAdd(ctx context.Context, key string, score float64, member string) error + + // ZRangeByScore returns members with scores between min and max + ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) + + // ZRemRangeByScore removes members with scores between min and max + ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) + + // Distributed lock helpers (used by background workers) + AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) + RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) + ReleaseLock(ctx context.Context, key string) error +} + +// StreamMessage represents a message from a Redis Stream +type StreamMessage struct { + ID string + Values map[string]interface{} +} + +// PendingInfo contains information about pending messages in a consumer group +type PendingInfo struct { + Count int64 + Lower string + Upper string + Consumers map[string]int64 } // LocalMessageBus is a no-op implementation for single-server mode @@ -78,3 +145,73 @@ func (l *LocalMessageBus) IsHealthy() bool { func (l *LocalMessageBus) Close() error { return nil } + +// ========== Redis Streams Operations (No-op for local mode) ========== + +func (l *LocalMessageBus) XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) { + return "0-0", nil +} + +func (l *LocalMessageBus) XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) { + return nil, nil +} + +func (l *LocalMessageBus) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) { + return 0, nil +} + +func (l *LocalMessageBus) XGroupCreate(ctx context.Context, stream, group, start string) error { + return nil +} + +func (l *LocalMessageBus) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error { + return nil +} + +func (l *LocalMessageBus) XPending(ctx context.Context, stream, group string) (*PendingInfo, error) { + return &PendingInfo{}, nil +} + +func (l *LocalMessageBus) XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) { + return nil, nil +} + +func (l *LocalMessageBus) XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) { + return nil, "0-0", nil +} + +func (l *LocalMessageBus) XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) { + return nil, nil +} + +func (l *LocalMessageBus) XTrimMinID(ctx context.Context, stream, minID string) (int64, error) { + return 0, nil +} + +func (l *LocalMessageBus) Incr(ctx context.Context, key string) (int64, error) { + return 0, nil +} + 
+func (l *LocalMessageBus) ZAdd(ctx context.Context, key string, score float64, member string) error { + return nil +} + +func (l *LocalMessageBus) ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) { + return nil, nil +} + +func (l *LocalMessageBus) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) { + return 0, nil +} + +func (l *LocalMessageBus) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) { + return true, nil +} + +func (l *LocalMessageBus) RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) { + return true, nil +} + +func (l *LocalMessageBus) ReleaseLock(ctx context.Context, key string) error { + return nil +} diff --git a/backend/internal/messagebus/redis.go b/backend/internal/messagebus/redis.go index 477cceb..fa6ef5d 100644 --- a/backend/internal/messagebus/redis.go +++ b/backend/internal/messagebus/redis.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "log" + "net" "strconv" "sync" "time" @@ -88,6 +89,23 @@ func NewRedisMessageBus(redisURL string, serverID string, logger *zap.Logger) (* // - Redis will handle stale connections via TCP keepalive opts.ConnMaxLifetime = 1 * time.Hour + // ================================ + // Socket-Level Timeout Configuration (prevents indefinite hangs) + // ================================ + // Without these, TCP reads/writes block indefinitely when Redis is unresponsive, + // causing OS-level timeouts (60-120s) instead of application-level control. + + // DialTimeout: How long to wait for initial connection establishment + opts.DialTimeout = 5 * time.Second + + // ReadTimeout: Maximum time for socket read operations + // - 30s is appropriate for PubSub (long intervals between messages are normal) + // - Prevents indefinite blocking when Redis hangs + opts.ReadTimeout = 30 * time.Second + + // WriteTimeout: Maximum time for socket write operations + opts.WriteTimeout = 5 * time.Second + client := goredis.NewClient(opts) // ================================ @@ -215,12 +233,15 @@ func (r *RedisMessageBus) readLoop(ctx context.Context, roomID string, sub *subs if ctx.Err() != nil { return } + r.logger.Warn("PubSub initial subscription failed, retrying with backoff", + zap.String("roomID", roomID), + zap.Error(err), + zap.Duration("backoff", backoff), + ) time.Sleep(backoff) - if backoff < maxBackoff { - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } + backoff = backoff * 2 + if backoff > maxBackoff { + backoff = maxBackoff } continue } @@ -242,12 +263,15 @@ func (r *RedisMessageBus) readLoop(ctx context.Context, roomID string, sub *subs if ctx.Err() != nil { return } + r.logger.Warn("PubSub receive failed, retrying with backoff", + zap.String("roomID", roomID), + zap.Error(err), + zap.Duration("backoff", backoff), + ) time.Sleep(backoff) - if backoff < maxBackoff { - backoff *= 2 - if backoff > maxBackoff { - backoff = maxBackoff - } + backoff = backoff * 2 + if backoff > maxBackoff { + backoff = maxBackoff } } } @@ -261,12 +285,15 @@ func (r *RedisMessageBus) receiveOnce(ctx context.Context, roomID string, pubsub msg, err := pubsub.ReceiveTimeout(ctx, 5*time.Second) if err != nil { - if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { - return err + if ctx.Err() != nil { + return ctx.Err() } if errors.Is(err, goredis.Nil) { continue } + if isTimeoutErr(err) { + continue + } r.logger.Warn("pubsub receive error, closing subscription", zap.String("roomID", roomID), zap.Error(err), @@ -308,6 
+335,17 @@ func (r *RedisMessageBus) receiveOnce(ctx context.Context, roomID string, pubsub } } +func isTimeoutErr(err error) bool { + if err == nil { + return false + } + if errors.Is(err, context.DeadlineExceeded) { + return true + } + var netErr net.Error + return errors.As(err, &netErr) && netErr.Timeout() +} + // Unsubscribe stops listening to a room func (r *RedisMessageBus) Unsubscribe(ctx context.Context, roomID string) error { r.subMu.Lock() @@ -430,7 +468,7 @@ func (r *RedisMessageBus) DeleteAwareness(ctx context.Context, roomID string, cl // IsHealthy checks Redis connectivity func (r *RedisMessageBus) IsHealthy() bool { - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() // 只有 Ping 成功且没有报错,才认为服务是健康的 @@ -516,3 +554,223 @@ func (r *RedisMessageBus) ClearAllAwareness(ctx context.Context, roomID string) // 直接使用 Del 命令删除整个 Key return r.client.Del(ctx, key).Err() } + +// ========== Redis Streams Operations ========== + +// XAdd adds a new entry to a stream with optional MAXLEN trimming +func (r *RedisMessageBus) XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) { + result := r.client.XAdd(ctx, &goredis.XAddArgs{ + Stream: stream, + MaxLen: maxLen, + Approx: approx, + Values: values, + }) + return result.Val(), result.Err() +} + +// XReadGroup reads messages from a stream using a consumer group +func (r *RedisMessageBus) XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) { + result := r.client.XReadGroup(ctx, &goredis.XReadGroupArgs{ + Group: group, + Consumer: consumer, + Streams: streams, + Count: count, + Block: block, + }) + + if err := result.Err(); err != nil { + // Timeout is not an error, just no new messages + if err == goredis.Nil { + return nil, nil + } + return nil, err + } + + // Convert go-redis XStream to our StreamMessage format + var messages []StreamMessage + for _, stream := range result.Val() { + for _, msg := range stream.Messages { + messages = append(messages, StreamMessage{ + ID: msg.ID, + Values: msg.Values, + }) + } + } + + return messages, nil +} + +// XAck acknowledges one or more messages from a consumer group +func (r *RedisMessageBus) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) { + result := r.client.XAck(ctx, stream, group, ids...) 
+ return result.Val(), result.Err() +} + +// XGroupCreate creates a new consumer group for a stream +func (r *RedisMessageBus) XGroupCreate(ctx context.Context, stream, group, start string) error { + return r.client.XGroupCreate(ctx, stream, group, start).Err() +} + +// XGroupCreateMkStream creates a consumer group and the stream if it doesn't exist +func (r *RedisMessageBus) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error { + return r.client.XGroupCreateMkStream(ctx, stream, group, start).Err() +} + +// XPending returns pending messages information for a consumer group +func (r *RedisMessageBus) XPending(ctx context.Context, stream, group string) (*PendingInfo, error) { + result := r.client.XPending(ctx, stream, group) + if err := result.Err(); err != nil { + return nil, err + } + + pending := result.Val() + consumers := make(map[string]int64) + for name, count := range pending.Consumers { + consumers[name] = count + } + + return &PendingInfo{ + Count: pending.Count, + Lower: pending.Lower, + Upper: pending.Higher, // go-redis uses "Higher" instead of "Upper" + Consumers: consumers, + }, nil +} + +// XClaim claims pending messages from a consumer group +func (r *RedisMessageBus) XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) { + result := r.client.XClaim(ctx, &goredis.XClaimArgs{ + Stream: stream, + Group: group, + Consumer: consumer, + MinIdle: minIdleTime, + Messages: ids, + }) + + if err := result.Err(); err != nil { + return nil, err + } + + // Convert go-redis XMessage to our StreamMessage format + var messages []StreamMessage + for _, msg := range result.Val() { + messages = append(messages, StreamMessage{ + ID: msg.ID, + Values: msg.Values, + }) + } + + return messages, nil +} + +// XAutoClaim claims pending messages automatically (Redis >= 6.2) +func (r *RedisMessageBus) XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) { + result := r.client.XAutoClaim(ctx, &goredis.XAutoClaimArgs{ + Stream: stream, + Group: group, + Consumer: consumer, + MinIdle: minIdleTime, + Start: start, + Count: count, + }) + msgs, nextStart, err := result.Result() + if err != nil { + return nil, "", err + } + messages := make([]StreamMessage, 0, len(msgs)) + for _, msg := range msgs { + messages = append(messages, StreamMessage{ + ID: msg.ID, + Values: msg.Values, + }) + } + return messages, nextStart, nil +} + +// XRange reads a range of messages from a stream +func (r *RedisMessageBus) XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) { + result := r.client.XRange(ctx, stream, start, end) + if err := result.Err(); err != nil { + return nil, err + } + + // Convert go-redis XMessage to our StreamMessage format + var messages []StreamMessage + for _, msg := range result.Val() { + messages = append(messages, StreamMessage{ + ID: msg.ID, + Values: msg.Values, + }) + } + + return messages, nil +} + +// XTrimMinID trims a stream to a minimum ID (time-based retention) +func (r *RedisMessageBus) XTrimMinID(ctx context.Context, stream, minID string) (int64, error) { + // Use XTRIM with MINID and approximation (~) for efficiency + // LIMIT clause prevents blocking Redis during large trims + result := r.client.Do(ctx, "XTRIM", stream, "MINID", "~", minID, "LIMIT", 1000) + if err := result.Err(); err != nil { + return 0, err + } + + // Result is the number of entries removed + 
trimmed, err := result.Int64() + if err != nil { + return 0, err + } + + return trimmed, nil +} + +// ========== Sorted Set (ZSET) Operations ========== + +// ZAdd adds a member with a score to a sorted set +func (r *RedisMessageBus) ZAdd(ctx context.Context, key string, score float64, member string) error { + return r.client.ZAdd(ctx, key, goredis.Z{Score: score, Member: member}).Err() +} + +// ZRangeByScore returns members with scores between min and max +func (r *RedisMessageBus) ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) { + return r.client.ZRangeByScore(ctx, key, &goredis.ZRangeBy{ + Min: strconv.FormatFloat(min, 'f', -1, 64), + Max: strconv.FormatFloat(max, 'f', -1, 64), + }).Result() +} + +// ZRemRangeByScore removes members with scores between min and max +func (r *RedisMessageBus) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) { + return r.client.ZRemRangeByScore(ctx, key, + strconv.FormatFloat(min, 'f', -1, 64), + strconv.FormatFloat(max, 'f', -1, 64), + ).Result() +} + +// Incr increments a counter atomically (for sequence numbers) +func (r *RedisMessageBus) Incr(ctx context.Context, key string) (int64, error) { + result := r.client.Incr(ctx, key) + return result.Val(), result.Err() +} + +// AcquireLock attempts to acquire a distributed lock with TTL +func (r *RedisMessageBus) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) { + return r.client.SetNX(ctx, key, r.serverID, ttl).Result() +} + +// RefreshLock extends the TTL on an existing lock +func (r *RedisMessageBus) RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) { + result := r.client.SetArgs(ctx, key, r.serverID, goredis.SetArgs{ + Mode: "XX", + TTL: ttl, + }) + if err := result.Err(); err != nil { + return false, err + } + return result.Val() == "OK", nil +} + +// ReleaseLock releases a distributed lock +func (r *RedisMessageBus) ReleaseLock(ctx context.Context, key string) error { + return r.client.Del(ctx, key).Err() +} diff --git a/backend/internal/models/stream_checkpoint.go b/backend/internal/models/stream_checkpoint.go new file mode 100644 index 0000000..3b1e2f4 --- /dev/null +++ b/backend/internal/models/stream_checkpoint.go @@ -0,0 +1,15 @@ +package models + +import ( + "time" + + "github.com/google/uuid" +) + +// StreamCheckpoint tracks the last processed Redis Stream entry per document +type StreamCheckpoint struct { + DocumentID uuid.UUID `json:"document_id"` + LastStreamID string `json:"last_stream_id"` + LastSeq int64 `json:"last_seq"` + UpdatedAt time.Time `json:"updated_at"` +} diff --git a/backend/internal/store/postgres.go b/backend/internal/store/postgres.go index 99da529..5385d21 100644 --- a/backend/internal/store/postgres.go +++ b/backend/internal/store/postgres.go @@ -53,6 +53,15 @@ type Store interface { GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error) GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error) + // Stream checkpoint operations + UpsertStreamCheckpoint(ctx context.Context, documentID uuid.UUID, streamID string, seq int64) error + GetStreamCheckpoint(ctx context.Context, documentID uuid.UUID) (*models.StreamCheckpoint, error) + + // Update history (WAL) operations + InsertUpdateHistoryBatch(ctx context.Context, entries []UpdateHistoryEntry) error + ListUpdateHistoryAfterSeq(ctx context.Context, documentID uuid.UUID, afterSeq int64, limit int) ([]UpdateHistoryEntry, 
error) + DeleteUpdateHistoryUpToSeq(ctx context.Context, documentID uuid.UUID, maxSeq int64) error + Close() error } diff --git a/backend/internal/store/stream_checkpoint.go b/backend/internal/store/stream_checkpoint.go new file mode 100644 index 0000000..1afff2d --- /dev/null +++ b/backend/internal/store/stream_checkpoint.go @@ -0,0 +1,46 @@ +package store + +import ( + "context" + "fmt" + + "github.com/M1ngdaXie/realtime-collab/internal/models" + "github.com/google/uuid" +) + +// UpsertStreamCheckpoint creates or updates the stream checkpoint for a document +func (s *PostgresStore) UpsertStreamCheckpoint(ctx context.Context, documentID uuid.UUID, streamID string, seq int64) error { + query := ` + INSERT INTO stream_checkpoints (document_id, last_stream_id, last_seq, updated_at) + VALUES ($1, $2, $3, NOW()) + ON CONFLICT (document_id) + DO UPDATE SET last_stream_id = EXCLUDED.last_stream_id, + last_seq = EXCLUDED.last_seq, + updated_at = NOW() + ` + + if _, err := s.db.ExecContext(ctx, query, documentID, streamID, seq); err != nil { + return fmt.Errorf("failed to upsert stream checkpoint: %w", err) + } + return nil +} + +// GetStreamCheckpoint retrieves the stream checkpoint for a document +func (s *PostgresStore) GetStreamCheckpoint(ctx context.Context, documentID uuid.UUID) (*models.StreamCheckpoint, error) { + query := ` + SELECT document_id, last_stream_id, last_seq, updated_at + FROM stream_checkpoints + WHERE document_id = $1 + ` + + var checkpoint models.StreamCheckpoint + if err := s.db.QueryRowContext(ctx, query, documentID).Scan( + &checkpoint.DocumentID, + &checkpoint.LastStreamID, + &checkpoint.LastSeq, + &checkpoint.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("failed to get stream checkpoint: %w", err) + } + return &checkpoint, nil +} diff --git a/backend/internal/store/testutil.go b/backend/internal/store/testutil.go index eb1cb83..7616383 100644 --- a/backend/internal/store/testutil.go +++ b/backend/internal/store/testutil.go @@ -71,10 +71,14 @@ func SetupTestDB(t *testing.T) (*PostgresStore, func()) { // Run migrations scriptsDir := filepath.Join("..", "..", "scripts") migrations := []string{ - "init.sql", - "001_add_users_and_sessions.sql", - "002_add_document_shares.sql", - "003_add_public_sharing.sql", + "000_extensions.sql", + "001_init_schema.sql", + "002_add_users_and_sessions.sql", + "003_add_document_shares.sql", + "004_add_public_sharing.sql", + "005_add_share_link_permission.sql", + "010_add_stream_checkpoints.sql", + "011_add_update_history.sql", } for _, migration := range migrations { @@ -107,6 +111,8 @@ func SetupTestDB(t *testing.T) (*PostgresStore, func()) { func TruncateAllTables(ctx context.Context, store *PostgresStore) error { tables := []string{ "document_updates", + "document_update_history", + "stream_checkpoints", "document_shares", "sessions", "documents", diff --git a/backend/internal/store/update_history.go b/backend/internal/store/update_history.go new file mode 100644 index 0000000..843ce98 --- /dev/null +++ b/backend/internal/store/update_history.go @@ -0,0 +1,115 @@ +package store + +import ( + "context" + "fmt" + "strings" + "time" + "unicode/utf8" + + "github.com/google/uuid" +) + +// UpdateHistoryEntry represents a persisted update from Redis Streams +// used for recovery and replay. +type UpdateHistoryEntry struct { + DocumentID uuid.UUID + StreamID string + Seq int64 + Payload []byte + MsgType string + ServerID string + CreatedAt time.Time +} + +// InsertUpdateHistoryBatch inserts update history entries in a single batch. 
+// Uses ON CONFLICT DO NOTHING to make inserts idempotent. +func (s *PostgresStore) InsertUpdateHistoryBatch(ctx context.Context, entries []UpdateHistoryEntry) error { + if len(entries) == 0 { + return nil + } + + var sb strings.Builder + sb.WriteString("INSERT INTO document_update_history (document_id, stream_id, seq, payload, msg_type, server_id, created_at) VALUES ") + + args := make([]interface{}, 0, len(entries)*7) + for i, e := range entries { + if i > 0 { + sb.WriteString(",") + } + base := i*7 + 1 + sb.WriteString(fmt.Sprintf("($%d,$%d,$%d,$%d,$%d,$%d,$%d)", base, base+1, base+2, base+3, base+4, base+5, base+6)) + msgType := sanitizeTextForDB(e.MsgType) + serverID := sanitizeTextForDB(e.ServerID) + args = append(args, e.DocumentID, e.StreamID, e.Seq, e.Payload, nullIfEmpty(msgType), nullIfEmpty(serverID), e.CreatedAt) + } + // Idempotent insert + sb.WriteString(" ON CONFLICT (document_id, stream_id) DO NOTHING") + + if _, err := s.db.ExecContext(ctx, sb.String(), args...); err != nil { + return fmt.Errorf("failed to insert update history batch: %w", err) + } + return nil +} + +// ListUpdateHistoryAfterSeq returns updates with seq greater than afterSeq, ordered by seq. +func (s *PostgresStore) ListUpdateHistoryAfterSeq(ctx context.Context, documentID uuid.UUID, afterSeq int64, limit int) ([]UpdateHistoryEntry, error) { + if limit <= 0 { + limit = 1000 + } + query := ` + SELECT document_id, stream_id, seq, payload, COALESCE(msg_type, ''), COALESCE(server_id, ''), created_at + FROM document_update_history + WHERE document_id = $1 AND seq > $2 + ORDER BY seq ASC + LIMIT $3 + ` + + rows, err := s.db.QueryContext(ctx, query, documentID, afterSeq, limit) + if err != nil { + return nil, fmt.Errorf("failed to list update history: %w", err) + } + defer rows.Close() + + var results []UpdateHistoryEntry + for rows.Next() { + var e UpdateHistoryEntry + if err := rows.Scan(&e.DocumentID, &e.StreamID, &e.Seq, &e.Payload, &e.MsgType, &e.ServerID, &e.CreatedAt); err != nil { + return nil, fmt.Errorf("failed to scan update history: %w", err) + } + results = append(results, e) + } + return results, nil +} + +// DeleteUpdateHistoryUpToSeq deletes updates with seq <= maxSeq for a document. 
+func (s *PostgresStore) DeleteUpdateHistoryUpToSeq(ctx context.Context, documentID uuid.UUID, maxSeq int64) error { + query := ` + DELETE FROM document_update_history + WHERE document_id = $1 AND seq <= $2 + ` + if _, err := s.db.ExecContext(ctx, query, documentID, maxSeq); err != nil { + return fmt.Errorf("failed to delete update history: %w", err) + } + return nil +} + +func nullIfEmpty(s string) interface{} { + if s == "" { + return nil + } + return s +} + +func sanitizeTextForDB(s string) string { + if s == "" { + return "" + } + if strings.IndexByte(s, 0) >= 0 { + return "" + } + if !utf8.ValidString(s) { + return "" + } + return s +} diff --git a/backend/internal/workers/update_persist_worker.go b/backend/internal/workers/update_persist_worker.go new file mode 100644 index 0000000..d20b7d6 --- /dev/null +++ b/backend/internal/workers/update_persist_worker.go @@ -0,0 +1,320 @@ +package workers + +import ( + "context" + "encoding/base64" + "fmt" + "runtime/debug" + "strconv" + "strings" + "time" + "unicode/utf8" + + "github.com/M1ngdaXie/realtime-collab/internal/messagebus" + "github.com/M1ngdaXie/realtime-collab/internal/store" + "github.com/google/uuid" + "go.uber.org/zap" +) + +const ( + updatePersistGroupName = "update-persist-worker" + updatePersistLockKey = "lock:update-persist-worker" + updatePersistLockTTL = 30 * time.Second + updatePersistTick = 2 * time.Second + updateReadCount = 200 + updateReadBlock = -1 // negative → go-redis omits BLOCK clause → non-blocking + updateBatchSize = 500 + updateSafeSeqLag = int64(1000) + updateAutoClaimIdle = 30 * time.Second + updateHeartbeatEvery = 30 * time.Second +) + +// StartUpdatePersistWorker persists Redis Stream updates into Postgres for recovery. +func StartUpdatePersistWorker(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) { + if msgBus == nil || dbStore == nil { + return + } + + for { + func() { + defer func() { + if r := recover(); r != nil { + logWorker(logger, "Update persist worker panic", + zap.Any("panic", r), + zap.ByteString("stack", debug.Stack())) + } + }() + + select { + case <-ctx.Done(): + return + default: + } + + acquired, err := msgBus.AcquireLock(ctx, updatePersistLockKey, updatePersistLockTTL) + if err != nil { + logWorker(logger, "Failed to acquire update persist worker lock", zap.Error(err)) + time.Sleep(updatePersistTick) + return + } + if !acquired { + time.Sleep(updatePersistTick) + return + } + + logWorker(logger, "Update persist worker lock acquired", zap.String("server_id", serverID)) + runUpdatePersistWorker(ctx, msgBus, dbStore, logger, serverID) + }() + + select { + case <-ctx.Done(): + return + default: + } + // If the worker exited (including panic), pause briefly before retry. 
+ time.Sleep(updatePersistTick) + } +} + +func runUpdatePersistWorker(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) { + ticker := time.NewTicker(updatePersistTick) + defer ticker.Stop() + + refreshTicker := time.NewTicker(updatePersistLockTTL / 2) + defer refreshTicker.Stop() + + heartbeatTicker := time.NewTicker(updateHeartbeatEvery) + defer heartbeatTicker.Stop() + + for { + select { + case <-ctx.Done(): + _ = msgBus.ReleaseLock(ctx, updatePersistLockKey) + return + case <-refreshTicker.C: + ok, err := msgBus.RefreshLock(ctx, updatePersistLockKey, updatePersistLockTTL) + if err != nil || !ok { + logWorker(logger, "Update persist worker lock lost", zap.Error(err)) + _ = msgBus.ReleaseLock(ctx, updatePersistLockKey) + return + } + case <-heartbeatTicker.C: + logWorker(logger, "Update persist worker heartbeat", zap.String("server_id", serverID)) + case <-ticker.C: + if err := processUpdatePersistence(ctx, msgBus, dbStore, logger, serverID); err != nil { + logWorker(logger, "Update persist worker tick failed", zap.Error(err)) + } + } + } +} + +func processUpdatePersistence(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) error { + // Only process documents with recent stream activity (active in the last 60 seconds) + cutoff := float64(time.Now().Add(-60 * time.Second).Unix()) + activeDocIDs, err := msgBus.ZRangeByScore(ctx, "active-streams", cutoff, float64(time.Now().Unix())) + if err != nil { + return fmt.Errorf("failed to get active streams: %w", err) + } + + // Prune stale entries older than 5 minutes (best-effort cleanup) + stale := float64(time.Now().Add(-5 * time.Minute).Unix()) + if _, err := msgBus.ZRemRangeByScore(ctx, "active-streams", 0, stale); err != nil { + logWorker(logger, "Failed to prune stale active-streams entries", zap.Error(err)) + } + + for _, docIDStr := range activeDocIDs { + docID, err := uuid.Parse(docIDStr) + if err != nil { + logWorker(logger, "Invalid document ID in active-streams", zap.String("doc_id", docIDStr)) + continue + } + + streamKey := "stream:" + docIDStr + if err := ensureConsumerGroup(ctx, msgBus, streamKey, updatePersistGroupName); err != nil { + logWorker(logger, "Failed to ensure update persist consumer group", zap.String("stream", streamKey), zap.Error(err)) + continue + } + + var ackIDs []string + docEntries := make([]store.UpdateHistoryEntry, 0, updateBatchSize) + + // First, try to claim idle pending messages (e.g., from previous crashes) + claimed, _, err := msgBus.XAutoClaim(ctx, streamKey, updatePersistGroupName, serverID, updateAutoClaimIdle, "0-0", updateReadCount) + if err != nil { + logWorker(logger, "XAutoClaim failed", zap.String("stream", streamKey), zap.Error(err)) + } else if len(claimed) > 0 { + collectStreamMessages(ctx, msgBus, dbStore, logger, docID, streamKey, claimed, &docEntries, &ackIDs) + } + + messages, err := msgBus.XReadGroup(ctx, updatePersistGroupName, serverID, []string{streamKey, ">"}, updateReadCount, updateReadBlock) + if err != nil { + logWorker(logger, "XReadGroup failed", zap.String("stream", streamKey), zap.Error(err)) + continue + } + if len(messages) > 0 { + collectStreamMessages(ctx, msgBus, dbStore, logger, docID, streamKey, messages, &docEntries, &ackIDs) + } + + if len(docEntries) > 0 { + if err := dbStore.InsertUpdateHistoryBatch(ctx, docEntries); err != nil { + logWorker(logger, "Failed to insert update history batch", zap.Error(err)) + // Skip ACK to retry on 
next tick + continue + } + } + + if len(ackIDs) > 0 { + if _, err := msgBus.XAck(ctx, streamKey, updatePersistGroupName, ackIDs...); err != nil { + logWorker(logger, "XAck failed", zap.String("stream", streamKey), zap.Error(err)) + } + } + } + + return nil +} + +func collectStreamMessages(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, documentID uuid.UUID, streamKey string, messages []messagebus.StreamMessage, docEntries *[]store.UpdateHistoryEntry, ackIDs *[]string) { + for _, msg := range messages { + msgType := getString(msg.Values["type"]) + switch msgType { + case "update": + payloadB64 := getString(msg.Values["yjs_payload"]) + payload, err := base64.StdEncoding.DecodeString(payloadB64) + if err != nil { + logWorker(logger, "Failed to decode update payload", + zap.String("stream", streamKey), + zap.String("stream_id", msg.ID), + zap.Error(err)) + continue + } + seq := parseInt64(msg.Values["seq"]) + msgType := normalizeMsgType(msg.Values["msg_type"]) + serverID := sanitizeText(getString(msg.Values["server_id"])) + entry := store.UpdateHistoryEntry{ + DocumentID: documentID, + StreamID: msg.ID, + Seq: seq, + Payload: payload, + MsgType: msgType, + ServerID: serverID, + CreatedAt: time.Now().UTC(), + } + *docEntries = append(*docEntries, entry) + case "snapshot": + seq := parseInt64(msg.Values["seq"]) + if seq > 0 { + if err := dbStore.UpsertStreamCheckpoint(ctx, documentID, msg.ID, seq); err != nil { + logWorker(logger, "Failed to upsert stream checkpoint from snapshot marker", + zap.String("document_id", documentID.String()), + zap.Error(err)) + } + // Retention: prune DB history based on checkpoint (best-effort) + maxSeq := seq - updateSafeSeqLag + if maxSeq > 0 { + if err := dbStore.DeleteUpdateHistoryUpToSeq(ctx, documentID, maxSeq); err != nil { + logWorker(logger, "Failed to prune update history", + zap.String("document_id", documentID.String()), + zap.Error(err)) + } + } + // Trim Redis stream to avoid unbounded growth (best-effort) + if _, err := msgBus.XTrimMinID(ctx, streamKey, msg.ID); err != nil { + logWorker(logger, "Failed to trim Redis stream", + zap.String("stream", streamKey), + zap.Error(err)) + } + } + } + *ackIDs = append(*ackIDs, msg.ID) + } +} + +func ensureConsumerGroup(ctx context.Context, msgBus messagebus.MessageBus, streamKey, group string) error { + if err := msgBus.XGroupCreateMkStream(ctx, streamKey, group, "0-0"); err != nil { + if !isBusyGroup(err) { + return err + } + } + return nil +} + +func isBusyGroup(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "BUSYGROUP") +} + +func getString(value interface{}) string { + switch v := value.(type) { + case string: + return v + case []byte: + return string(v) + default: + return fmt.Sprint(v) + } +} + +func parseInt64(value interface{}) int64 { + switch v := value.(type) { + case int64: + return v + case int: + return int64(v) + case uint64: + return int64(v) + case string: + if parsed, err := strconv.ParseInt(v, 10, 64); err == nil { + return parsed + } + case []byte: + if parsed, err := strconv.ParseInt(string(v), 10, 64); err == nil { + return parsed + } + } + return 0 +} + +func sanitizeText(s string) string { + if s == "" { + return s + } + if strings.IndexByte(s, 0) >= 0 { + return "" + } + if !utf8.ValidString(s) { + return "" + } + return s +} + +func normalizeMsgType(value interface{}) string { + switch v := value.(type) { + case string: + if v == "" { + return "" + } + if len(v) == 1 { + return 
strconv.Itoa(int(v[0])) + } + return sanitizeText(v) + case []byte: + if len(v) == 0 { + return "" + } + if len(v) == 1 { + return strconv.Itoa(int(v[0])) + } + return sanitizeText(string(v)) + default: + return sanitizeText(fmt.Sprint(v)) + } +} + +func logWorker(logger *zap.Logger, msg string, fields ...zap.Field) { + if logger == nil { + return + } + logger.Info(msg, fields...) +} diff --git a/backend/scripts/010_add_stream_checkpoints.sql b/backend/scripts/010_add_stream_checkpoints.sql new file mode 100644 index 0000000..dbbd543 --- /dev/null +++ b/backend/scripts/010_add_stream_checkpoints.sql @@ -0,0 +1,12 @@ +-- Migration: Add stream checkpoints table for Redis Streams durability +-- This table tracks last processed stream position per document + +CREATE TABLE IF NOT EXISTS stream_checkpoints ( + document_id UUID PRIMARY KEY REFERENCES documents(id) ON DELETE CASCADE, + last_stream_id TEXT NOT NULL, + last_seq BIGINT NOT NULL DEFAULT 0, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_stream_checkpoints_updated_at + ON stream_checkpoints(updated_at DESC); diff --git a/backend/scripts/011_add_update_history.sql b/backend/scripts/011_add_update_history.sql new file mode 100644 index 0000000..2c50f9e --- /dev/null +++ b/backend/scripts/011_add_update_history.sql @@ -0,0 +1,22 @@ +-- Migration: Add update history table for Redis Stream WAL +-- This table stores per-update payloads for recovery and replay + +CREATE TABLE IF NOT EXISTS document_update_history ( + id BIGSERIAL PRIMARY KEY, + document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE, + stream_id TEXT NOT NULL, + seq BIGINT NOT NULL, + payload BYTEA NOT NULL, + msg_type TEXT, + server_id TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_stream_id + ON document_update_history(document_id, stream_id); + +CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_seq + ON document_update_history(document_id, seq); + +CREATE INDEX IF NOT EXISTS idx_update_history_document_seq + ON document_update_history(document_id, seq); diff --git a/docker-compose.yml b/docker-compose.yml index 9ce7c3f..3eefd5b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -24,8 +24,11 @@ services: redis: image: redis:7-alpine container_name: realtime-collab-redis + command: ["redis-server", "--appendonly", "yes"] ports: - "6379:6379" + volumes: + - redis_data:/data healthcheck: test: ["CMD", "redis-cli", "ping"] interval: 10s @@ -34,3 +37,4 @@ services: volumes: postgres_data: + redis_data: diff --git a/frontend/src/api/document.ts b/frontend/src/api/document.ts index c1597ba..8504040 100644 --- a/frontend/src/api/document.ts +++ b/frontend/src/api/document.ts @@ -53,8 +53,11 @@ export const documentsApi = { }, // Get document Yjs state - getState: async (id: string): Promise => { - const response = await authFetch(`${API_BASE_URL}/documents/${id}/state`); + getState: async (id: string, shareToken?: string): Promise => { + const url = shareToken + ? 
`${API_BASE_URL}/documents/${id}/state?share=${shareToken}` + : `${API_BASE_URL}/documents/${id}/state`; + const response = await authFetch(url); if (!response.ok) throw new Error("Failed to fetch document state"); const arrayBuffer = await response.arrayBuffer(); return new Uint8Array(arrayBuffer); @@ -167,4 +170,4 @@ export const versionsApi = { if (!response.ok) throw new Error('Failed to restore version'); return response.json(); }, -}; \ No newline at end of file +}; diff --git a/frontend/src/hooks/useYjsDocument.ts b/frontend/src/hooks/useYjsDocument.ts index d2a2584..7845977 100644 --- a/frontend/src/hooks/useYjsDocument.ts +++ b/frontend/src/hooks/useYjsDocument.ts @@ -157,10 +157,34 @@ export const useYjsDocument = (documentId: string, shareToken?: string) => { setSynced(true); }); + // Connection stability monitoring with reconnection limits + let reconnectCount = 0; + const maxReconnects = 10; + yjsProviders.websocketProvider.on( "status", (event: { status: string }) => { console.log("WebSocket status:", event.status); + + if (event.status === "disconnected") { + reconnectCount++; + if (reconnectCount >= maxReconnects) { + console.error( + "Max reconnection attempts reached. Please refresh the page." + ); + // Could optionally show a user notification here + } else { + console.log( + `Reconnection attempt ${reconnectCount}/${maxReconnects}` + ); + } + } else if (event.status === "connected") { + // Reset counter on successful connection + if (reconnectCount > 0) { + console.log("Reconnected successfully, resetting counter"); + } + reconnectCount = 0; + } } ); diff --git a/frontend/src/lib/yjs.ts b/frontend/src/lib/yjs.ts index 241a1f8..ac25b97 100644 --- a/frontend/src/lib/yjs.ts +++ b/frontend/src/lib/yjs.ts @@ -30,7 +30,7 @@ export const createYjsDocument = async ( // Load initial state from database BEFORE connecting providers try { - const state = await documentsApi.getState(documentId); + const state = await documentsApi.getState(documentId, shareToken); if (state && state.length > 0) { Y.applyUpdate(ydoc, state); console.log('✓ Loaded document state from database'); @@ -51,7 +51,10 @@ export const createYjsDocument = async ( wsUrl, documentId, ydoc, - { params: wsParams } + { + params: wsParams, + maxBackoffTime: 10000, // Max 10s between reconnect attempts + } ); // Awareness for cursors and presence
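
Reviewer note (appended, not part of the diff): this patch changes the constructor signatures of `NewAuthMiddleware`, `NewDocumentHandler`, and `NewWebSocketHandler`, adds `workers.StartUpdatePersistWorker`, and introduces per-document Redis keys (`stream:<docID>`, `seq:<docID>`), the `active-streams` sorted set, and the `update-persist-worker` consumer group — but the server entry point is not touched here, so call sites must be updated separately. Below is a minimal wiring sketch, not code from this PR: the `bootstrap` package, `wireServer`, the `cfg.RedisURL` / `cfg.JWTSecret` fields, and the route paths are assumptions; only the signatures visible in this diff are taken as given.

```go
// Sketch only — shows how the new constructors and the persist worker might be
// wired at startup. cfg fields and route paths are placeholders (assumptions).
package bootstrap

import (
	"context"

	"github.com/M1ngdaXie/realtime-collab/internal/auth"
	"github.com/M1ngdaXie/realtime-collab/internal/config"
	"github.com/M1ngdaXie/realtime-collab/internal/handlers"
	"github.com/M1ngdaXie/realtime-collab/internal/hub"
	"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
	"github.com/M1ngdaXie/realtime-collab/internal/store"
	"github.com/M1ngdaXie/realtime-collab/internal/workers"
	"github.com/gin-gonic/gin"
	"go.uber.org/zap"
)

func wireServer(ctx context.Context, cfg *config.Config, db *store.PostgresStore, h *hub.Hub, logger *zap.Logger, serverID string) (*gin.Engine, error) {
	// Redis-backed bus; assumed to return (*RedisMessageBus, error).
	// messagebus.NewLocalMessageBus() (no-op streams/locks) can be substituted in single-server mode.
	msgBus, err := messagebus.NewRedisMessageBus(cfg.RedisURL, serverID, logger) // cfg.RedisURL is assumed
	if err != nil {
		return nil, err
	}

	// Background worker: drains "stream:<doc>" entries into document_update_history
	// and advances stream_checkpoints. It loops until ctx is cancelled, so run it
	// in its own goroutine.
	go workers.StartUpdatePersistWorker(ctx, msgBus, db, logger, serverID)

	authMW := auth.NewAuthMiddleware(db, cfg.JWTSecret, logger) // logger is the new third argument; cfg.JWTSecret is assumed
	docHandler := handlers.NewDocumentHandler(db, msgBus, serverID, logger)
	wsHandler := handlers.NewWebSocketHandler(h, db, cfg, msgBus)

	r := gin.New()
	// Illustrative routes: GetDocumentState now accepts ?share=<token> for
	// unauthenticated access, so it sits behind OptionalAuth rather than RequireAuth.
	r.GET("/documents/:id/state", authMW.OptionalAuth(), docHandler.GetDocumentState)
	r.PUT("/documents/:id/state", authMW.RequireAuth(), docHandler.UpdateDocumentState)
	r.GET("/ws/:id", wsHandler.HandleWebSocket)
	return r, nil
}
```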