feat: implement Redis Streams support with stream checkpoints and update history

- Added Redis Streams operations to the message bus interface and implementation.
- Introduced StreamCheckpoint model to track last processed stream entry per document.
- Implemented UpsertStreamCheckpoint and GetStreamCheckpoint methods in the Postgres store.
- Created document_update_history table for storing update payloads for recovery and replay.
- Developed update persist worker to handle Redis Stream updates and persist them to Postgres.
- Enhanced Docker Compose configuration for Redis with persistence.
- Updated frontend API to support fetching document state with optional share token.
- Added connection stability monitoring in the Yjs document hook.
This commit is contained in:
M1ngdaXie
2026-03-08 17:13:42 -07:00
parent f319e8ec75
commit 50822600ad
22 changed files with 1371 additions and 78 deletions

7
.gitignore vendored
View File

@@ -34,4 +34,9 @@ build/
# Docker volumes and data # Docker volumes and data
postgres_data/ postgres_data/
.claude/ .claude/
#test folder profiles
loadtest/pprof
/docs

View File

@@ -9,6 +9,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap"
) )
type contextKey string type contextKey string
@@ -20,41 +21,40 @@ const ContextUserIDKey = "user_id"
type AuthMiddleware struct { type AuthMiddleware struct {
store store.Store store store.Store
jwtSecret string jwtSecret string
logger *zap.Logger
} }
// NewAuthMiddleware creates a new auth middleware // NewAuthMiddleware creates a new auth middleware
func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware { func NewAuthMiddleware(store store.Store, jwtSecret string, logger *zap.Logger) *AuthMiddleware {
if logger == nil {
logger = zap.NewNop()
}
return &AuthMiddleware{ return &AuthMiddleware{
store: store, store: store,
jwtSecret: jwtSecret, jwtSecret: jwtSecret,
logger: logger,
} }
} }
// RequireAuth middleware requires valid authentication // RequireAuth middleware requires valid authentication
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc { func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
fmt.Println("🔒 RequireAuth: Starting authentication check")
user, claims, err := m.getUserFromToken(c) user, claims, err := m.getUserFromToken(c)
fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
if claims != nil {
fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
}
if err != nil || user == nil { if err != nil || user == nil {
fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user) if err != nil {
m.logger.Warn("auth failed",
zap.Error(err),
zap.String("method", c.Request.Method),
zap.String("path", c.FullPath()),
)
}
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"}) c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort() c.Abort()
return return
} }
// Note: Name and Email might be empty for old JWT tokens // Note: Name and Email might be empty for old JWT tokens
if claims.Name == "" || claims.Email == "" {
fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
}
fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
c.Set(ContextUserIDKey, user) c.Set(ContextUserIDKey, user)
c.Set("user_email", claims.Email) c.Set("user_email", claims.Email)
c.Set("user_name", claims.Name) c.Set("user_name", claims.Name)
@@ -88,21 +88,17 @@ func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error) // 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error)
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) { func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
authHeader := c.GetHeader("Authorization") authHeader := c.GetHeader("Authorization")
fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)
if authHeader == "" { if authHeader == "" {
fmt.Println("⚠️ getUserFromToken: No Authorization header")
return nil, nil, nil return nil, nil, nil
} }
parts := strings.Split(authHeader, " ") parts := strings.Split(authHeader, " ")
if len(parts) != 2 || parts[0] != "Bearer" { if len(parts) != 2 || parts[0] != "Bearer" {
fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
return nil, nil, nil return nil, nil, nil
} }
tokenString := parts[1] tokenString := parts[1]
fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
// 必须要验证签名算法是 HMAC (HS256) // 必须要验证签名算法是 HMAC (HS256)
@@ -113,7 +109,6 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
}) })
if err != nil { if err != nil {
fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
return nil, nil, err return nil, nil, err
} }
@@ -123,17 +118,14 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String() // 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
userID, err := uuid.Parse(claims.Subject) userID, err := uuid.Parse(claims.Subject)
if err != nil { if err != nil {
fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
return nil, nil, fmt.Errorf("invalid user ID in token") return nil, nil, fmt.Errorf("invalid user ID in token")
} }
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email) // 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
// 这一步完全没有查数据库,速度极快 // 这一步完全没有查数据库,速度极快
fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
return &userID, claims, nil return &userID, claims, nil
} }
fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
return nil, nil, fmt.Errorf("invalid token claims") return nil, nil, fmt.Errorf("invalid token claims")
} }
@@ -141,8 +133,6 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
func GetUserFromContext(c *gin.Context) *uuid.UUID { func GetUserFromContext(c *gin.Context) *uuid.UUID {
// 修正点:使用和存入时完全一样的 Key // 修正点:使用和存入时完全一样的 Key
val, exists := c.Get(ContextUserIDKey) val, exists := c.Get(ContextUserIDKey)
fmt.Println("within getFromContext the id is ... ")
fmt.Println(val)
if !exists { if !exists {
return nil return nil
} }

View File

@@ -1,22 +1,33 @@
package handlers package handlers
import ( import (
"fmt" "context"
"net/http" "net/http"
"time"
"github.com/M1ngdaXie/realtime-collab/internal/auth" "github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
"github.com/M1ngdaXie/realtime-collab/internal/models" "github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/M1ngdaXie/realtime-collab/internal/store" "github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"go.uber.org/zap"
) )
type DocumentHandler struct { type DocumentHandler struct {
store *store.PostgresStore store *store.PostgresStore
messageBus messagebus.MessageBus
serverID string
logger *zap.Logger
} }
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler { func NewDocumentHandler(s *store.PostgresStore, msgBus messagebus.MessageBus, serverID string, logger *zap.Logger) *DocumentHandler {
return &DocumentHandler{store: s} return &DocumentHandler{
store: s,
messageBus: msgBus,
serverID: serverID,
logger: logger,
}
} }
// CreateDocument creates a new document (requires auth) // CreateDocument creates a new document (requires auth)
@@ -45,8 +56,6 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
func (h *DocumentHandler) ListDocuments(c *gin.Context) { func (h *DocumentHandler) ListDocuments(c *gin.Context) {
userID := auth.GetUserFromContext(c) userID := auth.GetUserFromContext(c)
fmt.Println("Getting userId, which is : ")
fmt.Println(userID)
if userID == nil { if userID == nil {
respondUnauthorized(c, "Authentication required to list documents") respondUnauthorized(c, "Authentication required to list documents")
return return
@@ -113,6 +122,13 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
} }
userID := auth.GetUserFromContext(c) userID := auth.GetUserFromContext(c)
shareToken := c.Query("share")
doc, err := h.store.GetDocument(id)
if err != nil {
respondNotFound(c, "document")
return
}
// Check permission if authenticated // Check permission if authenticated
if userID != nil { if userID != nil {
@@ -125,12 +141,22 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
respondForbidden(c, "Access denied") respondForbidden(c, "Access denied")
return return
} }
} } else {
// Unauthenticated: require valid share token or public doc
doc, err := h.store.GetDocument(id) if shareToken != "" {
if err != nil { valid, err := h.store.ValidateShareToken(c.Request.Context(), id, shareToken)
respondNotFound(c, "document") if err != nil {
return respondInternalError(c, "Failed to validate share token", err)
return
}
if !valid {
respondForbidden(c, "Invalid or expired share token")
return
}
} else if !doc.Is_Public {
respondForbidden(c, "This document is not public. Please sign in to access.")
return
}
} }
// Return empty byte slice if state is nil (new document) // Return empty byte slice if state is nil (new document)
@@ -191,6 +217,16 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
return return
} }
if streamID, seq, ok := h.addSnapshotMarker(c.Request.Context(), id); ok {
if err := h.store.UpsertStreamCheckpoint(c.Request.Context(), id, streamID, seq); err != nil {
if h.logger != nil {
h.logger.Warn("Failed to upsert stream checkpoint after snapshot",
zap.String("document_id", id.String()),
zap.Error(err))
}
}
}
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"}) c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
} }
@@ -234,6 +270,43 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"}) c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"})
} }
func (h *DocumentHandler) addSnapshotMarker(ctx context.Context, documentID uuid.UUID) (string, int64, bool) {
if h.messageBus == nil {
return "", 0, false
}
streamKey := "stream:" + documentID.String()
seqKey := "seq:" + documentID.String()
seq, err := h.messageBus.Incr(ctx, seqKey)
if err != nil {
if h.logger != nil {
h.logger.Warn("Failed to increment snapshot sequence",
zap.String("document_id", documentID.String()),
zap.Error(err))
}
return "", 0, false
}
values := map[string]interface{}{
"type": "snapshot",
"server_id": h.serverID,
"seq": seq,
"timestamp": time.Now().Format(time.RFC3339),
}
streamID, err := h.messageBus.XAdd(ctx, streamKey, 10000, true, values)
if err != nil {
if h.logger != nil {
h.logger.Warn("Failed to add snapshot marker to stream",
zap.String("stream_key", streamKey),
zap.Error(err))
}
return "", 0, false
}
return streamID, seq, true
}
// GetDocumentPermission returns the user's permission level for a document // GetDocumentPermission returns the user's permission level for a document
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) { func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id")) documentID, err := uuid.Parse(c.Param("id"))

View File

@@ -7,10 +7,12 @@ import (
"testing" "testing"
"github.com/M1ngdaXie/realtime-collab/internal/auth" "github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
"github.com/M1ngdaXie/realtime-collab/internal/models" "github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"go.uber.org/zap"
) )
// DocumentHandlerSuite tests document CRUD operations // DocumentHandlerSuite tests document CRUD operations
@@ -23,7 +25,7 @@ type DocumentHandlerSuite struct {
// SetupTest runs before each test // SetupTest runs before each test
func (s *DocumentHandlerSuite) SetupTest() { func (s *DocumentHandlerSuite) SetupTest() {
s.BaseHandlerSuite.SetupTest() s.BaseHandlerSuite.SetupTest()
s.handler = NewDocumentHandler(s.store) s.handler = NewDocumentHandler(s.store, messagebus.NewLocalMessageBus(), "test-server", zap.NewNop())
s.setupRouter() s.setupRouter()
} }

View File

@@ -10,6 +10,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"go.uber.org/zap"
) )
// ShareHandlerSuite tests for share handler endpoints // ShareHandlerSuite tests for share handler endpoints
@@ -24,7 +25,7 @@ func (s *ShareHandlerSuite) SetupTest() {
s.BaseHandlerSuite.SetupTest() s.BaseHandlerSuite.SetupTest()
// Create handler and router // Create handler and router
authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret) authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret, zap.NewNop())
s.handler = NewShareHandler(s.store, s.cfg) s.handler = NewShareHandler(s.store, s.cfg)
s.router = gin.New() s.router = gin.New()

View File

@@ -1,13 +1,18 @@
package handlers package handlers
import ( import (
"context"
"encoding/base64"
"fmt"
"log" "log"
"net/http" "net/http"
"strconv"
"time" "time"
"github.com/M1ngdaXie/realtime-collab/internal/auth" "github.com/M1ngdaXie/realtime-collab/internal/auth"
"github.com/M1ngdaXie/realtime-collab/internal/config" "github.com/M1ngdaXie/realtime-collab/internal/config"
"github.com/M1ngdaXie/realtime-collab/internal/hub" "github.com/M1ngdaXie/realtime-collab/internal/hub"
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
"github.com/M1ngdaXie/realtime-collab/internal/store" "github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/uuid" "github.com/google/uuid"
@@ -19,16 +24,18 @@ import (
var connectionSem = make(chan struct{}, 200) var connectionSem = make(chan struct{}, 200)
type WebSocketHandler struct { type WebSocketHandler struct {
hub *hub.Hub hub *hub.Hub
store store.Store store store.Store
cfg *config.Config cfg *config.Config
msgBus messagebus.MessageBus
} }
func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config) *WebSocketHandler { func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config, msgBus messagebus.MessageBus) *WebSocketHandler {
return &WebSocketHandler{ return &WebSocketHandler{
hub: h, hub: h,
store: s, store: s,
cfg: cfg, cfg: cfg,
msgBus: msgBus,
} }
} }
@@ -170,6 +177,105 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
// Start goroutines // Start goroutines
go client.WritePump() go client.WritePump()
go client.ReadPump() go client.ReadPump()
go wsh.replayBacklog(client, documentID)
log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID) log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID)
} }
const maxReplayUpdates = 5000
func (wsh *WebSocketHandler) replayBacklog(client *hub.Client, documentID uuid.UUID) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
checkpoint, err := wsh.store.GetStreamCheckpoint(ctx, documentID)
if err != nil || checkpoint == nil || checkpoint.LastStreamID == "" {
return
}
streamKey := "stream:" + documentID.String()
var sent int
// Primary: Redis stream replay
if wsh.msgBus != nil {
messages, err := wsh.msgBus.XRange(ctx, streamKey, checkpoint.LastStreamID, "+")
if err == nil && len(messages) > 0 {
for _, msg := range messages {
if msg.ID == checkpoint.LastStreamID {
continue
}
if sent >= maxReplayUpdates {
log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String())
return
}
msgType := getString(msg.Values["type"])
if msgType != "update" {
continue
}
seq := parseInt64(msg.Values["seq"])
if seq <= checkpoint.LastSeq {
continue
}
payloadB64 := getString(msg.Values["yjs_payload"])
payload, err := base64.StdEncoding.DecodeString(payloadB64)
if err != nil {
continue
}
if client.Enqueue(payload) {
sent++
} else {
return
}
}
return
}
}
// Fallback: DB history replay
updates, err := wsh.store.ListUpdateHistoryAfterSeq(ctx, documentID, checkpoint.LastSeq, maxReplayUpdates)
if err != nil {
return
}
for _, upd := range updates {
if sent >= maxReplayUpdates {
log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String())
return
}
if client.Enqueue(upd.Payload) {
sent++
} else {
return
}
}
}
func getString(value interface{}) string {
switch v := value.(type) {
case string:
return v
case []byte:
return string(v)
default:
return fmt.Sprint(v)
}
}
func parseInt64(value interface{}) int64 {
switch v := value.(type) {
case int64:
return v
case int:
return int64(v)
case uint64:
return int64(v)
case string:
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
return parsed
}
case []byte:
if parsed, err := strconv.ParseInt(string(v), 10, 64); err == nil {
return parsed
}
}
return 0
}

View File

@@ -2,6 +2,8 @@ package hub
import ( import (
"context" "context"
"encoding/base64"
"strconv"
"sync" "sync"
"time" "time"
@@ -37,10 +39,11 @@ type Client struct {
idsMu sync.Mutex idsMu sync.Mutex
} }
type Room struct { type Room struct {
ID string ID string
clients map[*Client]bool clients map[*Client]bool
mu sync.RWMutex mu sync.RWMutex
cancel context.CancelFunc cancel context.CancelFunc
reconnectCount int // Track Redis reconnection attempts for debugging
} }
type Hub struct { type Hub struct {
@@ -64,6 +67,10 @@ type Hub struct {
// Bounded worker pool for Redis SetAwareness // Bounded worker pool for Redis SetAwareness
awarenessQueue chan awarenessItem awarenessQueue chan awarenessItem
// Stream persistence worker pool (P1: Redis Streams durability)
streamQueue chan *Message // buffered queue for XADD operations
streamDone chan struct{} // close to signal stream workers to exit
} }
const ( const (
@@ -79,6 +86,13 @@ const (
// awarenessQueueSize is the buffer size for awareness updates. // awarenessQueueSize is the buffer size for awareness updates.
awarenessQueueSize = 4096 awarenessQueueSize = 4096
// streamWorkerCount is the number of fixed goroutines consuming from streamQueue.
// 50 workers match publish workers for consistent throughput.
streamWorkerCount = 50
// streamQueueSize is the buffer size for the stream persistence queue.
streamQueueSize = 4096
) )
type awarenessItem struct { type awarenessItem struct {
@@ -103,11 +117,15 @@ func NewHub(messagebus messagebus.MessageBus, serverID string, logger *zap.Logge
publishDone: make(chan struct{}), publishDone: make(chan struct{}),
// bounded awareness worker pool // bounded awareness worker pool
awarenessQueue: make(chan awarenessItem, awarenessQueueSize), awarenessQueue: make(chan awarenessItem, awarenessQueueSize),
// Stream persistence worker pool
streamQueue: make(chan *Message, streamQueueSize),
streamDone: make(chan struct{}),
} }
// Start the fixed worker pool for Redis publishing // Start the fixed worker pool for Redis publishing
h.startPublishWorkers(publishWorkerCount) h.startPublishWorkers(publishWorkerCount)
h.startAwarenessWorkers(awarenessWorkerCount) h.startAwarenessWorkers(awarenessWorkerCount)
h.startStreamWorkers(streamWorkerCount)
return h return h
} }
@@ -173,6 +191,82 @@ func (h *Hub) startAwarenessWorkers(n int) {
h.logger.Info("Awareness worker pool started", zap.Int("workers", n)) h.logger.Info("Awareness worker pool started", zap.Int("workers", n))
} }
// startStreamWorkers launches n goroutines that consume from streamQueue
// and add messages to Redis Streams for durability and replay.
func (h *Hub) startStreamWorkers(n int) {
for i := 0; i < n; i++ {
go func(workerID int) {
for {
select {
case <-h.streamDone:
h.logger.Info("Stream worker exiting", zap.Int("worker_id", workerID))
return
case msg, ok := <-h.streamQueue:
if !ok {
return
}
h.addToStream(msg)
}
}
}(i)
}
h.logger.Info("Stream worker pool started", zap.Int("workers", n))
}
// encodeBase64 encodes binary data to base64 string for Redis storage
func encodeBase64(data []byte) string {
return base64.StdEncoding.EncodeToString(data)
}
// addToStream adds a message to Redis Streams for durability
func (h *Hub) addToStream(msg *Message) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
streamKey := "stream:" + msg.RoomID
// Get next sequence number atomically
seqKey := "seq:" + msg.RoomID
seq, err := h.messagebus.Incr(ctx, seqKey)
if err != nil {
h.logger.Error("Failed to increment sequence",
zap.String("room_id", msg.RoomID),
zap.Error(err))
return
}
// Encode payload as base64 (binary-safe storage)
payload := encodeBase64(msg.Data)
// Extract Yjs message type from first byte as numeric string
msgType := "0"
if len(msg.Data) > 0 {
msgType = strconv.Itoa(int(msg.Data[0]))
}
// Add entry to Stream with MAXLEN trimming
values := map[string]interface{}{
"type": "update",
"server_id": h.serverID,
"yjs_payload": payload,
"msg_type": msgType,
"seq": seq,
"timestamp": time.Now().Format(time.RFC3339),
}
_, err = h.messagebus.XAdd(ctx, streamKey, 10000, true, values)
if err != nil {
h.logger.Error("Failed to add to Stream",
zap.String("stream_key", streamKey),
zap.Int64("seq", seq),
zap.Error(err))
return
}
// Mark this document as active so the persist worker only processes active streams
_ = h.messagebus.ZAdd(ctx, "active-streams", float64(time.Now().Unix()), msg.RoomID)
}
func (h *Hub) Run() { func (h *Hub) Run() {
for { for {
select { select {
@@ -471,6 +565,7 @@ func (h *Hub) broadcastMessage(message *Message) {
// 只有本地客户端发出的消息 (sender != nil) 才推送到 Redis // 只有本地客户端发出的消息 (sender != nil) 才推送到 Redis
// P0 fix: send to bounded worker pool instead of spawning unbounded goroutines // P0 fix: send to bounded worker pool instead of spawning unbounded goroutines
if message.sender != nil && !h.fallbackMode && h.messagebus != nil { if message.sender != nil && !h.fallbackMode && h.messagebus != nil {
// 3a. Publish to Pub/Sub (real-time cross-server broadcast)
select { select {
case h.publishQueue <- message: case h.publishQueue <- message:
// Successfully queued for async publish by worker pool // Successfully queued for async publish by worker pool
@@ -479,6 +574,19 @@ func (h *Hub) broadcastMessage(message *Message) {
h.logger.Warn("Publish queue full, dropping Redis publish", h.logger.Warn("Publish queue full, dropping Redis publish",
zap.String("room_id", message.RoomID)) zap.String("room_id", message.RoomID))
} }
// 3b. Add to Stream for durability (only Type 0 updates, not Type 1 awareness)
// Type 0 = Yjs sync/update messages (document changes)
// Type 1 = Yjs awareness messages (cursors, presence) - ephemeral, skip
if len(message.Data) > 0 && message.Data[0] == 0 {
select {
case h.streamQueue <- message:
// Successfully queued for async Stream add
default:
h.logger.Warn("Stream queue full, dropping durability",
zap.String("room_id", message.RoomID))
}
}
} }
} }
@@ -504,10 +612,28 @@ func (h *Hub) broadcastToLocalClients(room *Room, data []byte, sender *Client) {
} }
} }
func (h *Hub) startRoomMessageForwarding(ctx context.Context, roomID string, msgChan <-chan []byte) { func (h *Hub) startRoomMessageForwarding(ctx context.Context, roomID string, msgChan <-chan []byte) {
h.logger.Info("Starting message forwarding from Redis to room", // Increment and log reconnection count for debugging
zap.String("room_id", roomID), h.mu.RLock()
zap.String("server_id", h.serverID), room, exists := h.rooms[roomID]
) h.mu.RUnlock()
if exists {
room.mu.Lock()
room.reconnectCount++
reconnectCount := room.reconnectCount
room.mu.Unlock()
h.logger.Info("Starting message forwarding from Redis to room",
zap.String("room_id", roomID),
zap.String("server_id", h.serverID),
zap.Int("reconnect_count", reconnectCount),
)
} else {
h.logger.Info("Starting message forwarding from Redis to room",
zap.String("room_id", roomID),
zap.String("server_id", h.serverID),
)
}
for { for {
select { select {
@@ -791,12 +917,28 @@ func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string
UserAvatar: userAvatar, UserAvatar: userAvatar,
Permission: permission, Permission: permission,
Conn: conn, Conn: conn,
send: make(chan []byte, 1024), send: make(chan []byte, 8192),
hub: hub, hub: hub,
roomID: roomID, roomID: roomID,
observedYjsIDs: make(map[uint64]uint64), observedYjsIDs: make(map[uint64]uint64),
} }
} }
// Enqueue sends a message to the client send buffer (non-blocking).
// Returns false if the buffer is full.
func (c *Client) Enqueue(message []byte) bool {
select {
case c.send <- message:
return true
default:
if c.hub != nil && c.hub.logger != nil {
c.hub.logger.Warn("Client send buffer full during replay",
zap.String("client_id", c.ID),
zap.String("room_id", c.roomID))
}
return false
}
}
func (c *Client) unregister() { func (c *Client) unregister() {
c.unregisterOnce.Do(func() { c.unregisterOnce.Do(func() {
c.hub.Unregister <- c c.hub.Unregister <- c

View File

@@ -23,7 +23,7 @@ func NewLogger(isDevelopment bool) (*zap.Logger, error) {
// 👇 关键修改:直接拉到 Fatal 级别 // 👇 关键修改:直接拉到 Fatal 级别
// 这样 Error, Warn, Info, Debug 全部都会被忽略 // 这样 Error, Warn, Info, Debug 全部都会被忽略
// 彻底消除 IO 锁竞争 // 彻底消除 IO 锁竞争
config.Level = zap.NewAtomicLevelAt(zapcore.FatalLevel) config.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel)
logger, err := config.Build() logger, err := config.Build()
if err != nil { if err != nil {

View File

@@ -2,6 +2,7 @@ package messagebus
import ( import (
"context" "context"
"time"
) )
// MessageBus abstracts message distribution across server instances // MessageBus abstracts message distribution across server instances
@@ -33,6 +34,72 @@ type MessageBus interface {
// Close gracefully shuts down the message bus // Close gracefully shuts down the message bus
Close() error Close() error
// ========== Redis Streams Operations ==========
// XAdd adds a new entry to a stream with optional MAXLEN trimming
XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error)
// XReadGroup reads messages from a stream using a consumer group
XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error)
// XAck acknowledges one or more messages from a consumer group
XAck(ctx context.Context, stream, group string, ids ...string) (int64, error)
// XGroupCreate creates a new consumer group for a stream
XGroupCreate(ctx context.Context, stream, group, start string) error
// XGroupCreateMkStream creates a consumer group and the stream if it doesn't exist
XGroupCreateMkStream(ctx context.Context, stream, group, start string) error
// XPending returns pending messages information for a consumer group
XPending(ctx context.Context, stream, group string) (*PendingInfo, error)
// XClaim claims pending messages from a consumer group
XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error)
// XAutoClaim claims pending messages automatically (Redis >= 6.2)
// Returns claimed messages and next start ID.
XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error)
// XRange reads a range of messages from a stream
XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error)
// XTrimMinID trims a stream to a minimum ID (time-based retention)
XTrimMinID(ctx context.Context, stream, minID string) (int64, error)
// Incr increments a counter atomically (for sequence numbers)
Incr(ctx context.Context, key string) (int64, error)
// ========== Sorted Set (ZSET) Operations ==========
// ZAdd adds a member with a score to a sorted set (used for active-stream tracking)
ZAdd(ctx context.Context, key string, score float64, member string) error
// ZRangeByScore returns members with scores between min and max
ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error)
// ZRemRangeByScore removes members with scores between min and max
ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error)
// Distributed lock helpers (used by background workers)
AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error)
RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error)
ReleaseLock(ctx context.Context, key string) error
}
// StreamMessage represents a message from a Redis Stream
type StreamMessage struct {
ID string
Values map[string]interface{}
}
// PendingInfo contains information about pending messages in a consumer group
type PendingInfo struct {
Count int64
Lower string
Upper string
Consumers map[string]int64
} }
// LocalMessageBus is a no-op implementation for single-server mode // LocalMessageBus is a no-op implementation for single-server mode
@@ -78,3 +145,73 @@ func (l *LocalMessageBus) IsHealthy() bool {
func (l *LocalMessageBus) Close() error { func (l *LocalMessageBus) Close() error {
return nil return nil
} }
// ========== Redis Streams Operations (No-op for local mode) ==========
func (l *LocalMessageBus) XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) {
return "0-0", nil
}
func (l *LocalMessageBus) XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) {
return nil, nil
}
func (l *LocalMessageBus) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) {
return 0, nil
}
func (l *LocalMessageBus) XGroupCreate(ctx context.Context, stream, group, start string) error {
return nil
}
func (l *LocalMessageBus) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error {
return nil
}
func (l *LocalMessageBus) XPending(ctx context.Context, stream, group string) (*PendingInfo, error) {
return &PendingInfo{}, nil
}
func (l *LocalMessageBus) XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) {
return nil, nil
}
func (l *LocalMessageBus) XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) {
return nil, "0-0", nil
}
func (l *LocalMessageBus) XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) {
return nil, nil
}
func (l *LocalMessageBus) XTrimMinID(ctx context.Context, stream, minID string) (int64, error) {
return 0, nil
}
func (l *LocalMessageBus) Incr(ctx context.Context, key string) (int64, error) {
return 0, nil
}
func (l *LocalMessageBus) ZAdd(ctx context.Context, key string, score float64, member string) error {
return nil
}
func (l *LocalMessageBus) ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) {
return nil, nil
}
func (l *LocalMessageBus) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) {
return 0, nil
}
func (l *LocalMessageBus) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
return true, nil
}
func (l *LocalMessageBus) RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
return true, nil
}
func (l *LocalMessageBus) ReleaseLock(ctx context.Context, key string) error {
return nil
}

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"log" "log"
"net"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@@ -88,6 +89,23 @@ func NewRedisMessageBus(redisURL string, serverID string, logger *zap.Logger) (*
// - Redis will handle stale connections via TCP keepalive // - Redis will handle stale connections via TCP keepalive
opts.ConnMaxLifetime = 1 * time.Hour opts.ConnMaxLifetime = 1 * time.Hour
// ================================
// Socket-Level Timeout Configuration (prevents indefinite hangs)
// ================================
// Without these, TCP reads/writes block indefinitely when Redis is unresponsive,
// causing OS-level timeouts (60-120s) instead of application-level control.
// DialTimeout: How long to wait for initial connection establishment
opts.DialTimeout = 5 * time.Second
// ReadTimeout: Maximum time for socket read operations
// - 30s is appropriate for PubSub (long intervals between messages are normal)
// - Prevents indefinite blocking when Redis hangs
opts.ReadTimeout = 30 * time.Second
// WriteTimeout: Maximum time for socket write operations
opts.WriteTimeout = 5 * time.Second
client := goredis.NewClient(opts) client := goredis.NewClient(opts)
// ================================ // ================================
@@ -215,12 +233,15 @@ func (r *RedisMessageBus) readLoop(ctx context.Context, roomID string, sub *subs
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
r.logger.Warn("PubSub initial subscription failed, retrying with backoff",
zap.String("roomID", roomID),
zap.Error(err),
zap.Duration("backoff", backoff),
)
time.Sleep(backoff) time.Sleep(backoff)
if backoff < maxBackoff { backoff = backoff * 2
backoff *= 2 if backoff > maxBackoff {
if backoff > maxBackoff { backoff = maxBackoff
backoff = maxBackoff
}
} }
continue continue
} }
@@ -242,12 +263,15 @@ func (r *RedisMessageBus) readLoop(ctx context.Context, roomID string, sub *subs
if ctx.Err() != nil { if ctx.Err() != nil {
return return
} }
r.logger.Warn("PubSub receive failed, retrying with backoff",
zap.String("roomID", roomID),
zap.Error(err),
zap.Duration("backoff", backoff),
)
time.Sleep(backoff) time.Sleep(backoff)
if backoff < maxBackoff { backoff = backoff * 2
backoff *= 2 if backoff > maxBackoff {
if backoff > maxBackoff { backoff = maxBackoff
backoff = maxBackoff
}
} }
} }
} }
@@ -261,12 +285,15 @@ func (r *RedisMessageBus) receiveOnce(ctx context.Context, roomID string, pubsub
msg, err := pubsub.ReceiveTimeout(ctx, 5*time.Second) msg, err := pubsub.ReceiveTimeout(ctx, 5*time.Second)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { if ctx.Err() != nil {
return err return ctx.Err()
} }
if errors.Is(err, goredis.Nil) { if errors.Is(err, goredis.Nil) {
continue continue
} }
if isTimeoutErr(err) {
continue
}
r.logger.Warn("pubsub receive error, closing subscription", r.logger.Warn("pubsub receive error, closing subscription",
zap.String("roomID", roomID), zap.String("roomID", roomID),
zap.Error(err), zap.Error(err),
@@ -308,6 +335,17 @@ func (r *RedisMessageBus) receiveOnce(ctx context.Context, roomID string, pubsub
} }
} }
func isTimeoutErr(err error) bool {
if err == nil {
return false
}
if errors.Is(err, context.DeadlineExceeded) {
return true
}
var netErr net.Error
return errors.As(err, &netErr) && netErr.Timeout()
}
// Unsubscribe stops listening to a room // Unsubscribe stops listening to a room
func (r *RedisMessageBus) Unsubscribe(ctx context.Context, roomID string) error { func (r *RedisMessageBus) Unsubscribe(ctx context.Context, roomID string) error {
r.subMu.Lock() r.subMu.Lock()
@@ -430,7 +468,7 @@ func (r *RedisMessageBus) DeleteAwareness(ctx context.Context, roomID string, cl
// IsHealthy checks Redis connectivity // IsHealthy checks Redis connectivity
func (r *RedisMessageBus) IsHealthy() bool { func (r *RedisMessageBus) IsHealthy() bool {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
// 只有 Ping 成功且没有报错,才认为服务是健康的 // 只有 Ping 成功且没有报错,才认为服务是健康的
@@ -516,3 +554,223 @@ func (r *RedisMessageBus) ClearAllAwareness(ctx context.Context, roomID string)
// 直接使用 Del 命令删除整个 Key // 直接使用 Del 命令删除整个 Key
return r.client.Del(ctx, key).Err() return r.client.Del(ctx, key).Err()
} }
// ========== Redis Streams Operations ==========

// XAdd appends an entry to a stream, optionally trimming it to roughly
// maxLen entries (approx=true uses the cheaper "~" trimming strategy).
// Returns the entry ID that Redis assigned.
func (r *RedisMessageBus) XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) {
	args := &goredis.XAddArgs{
		Stream: stream,
		MaxLen: maxLen,
		Approx: approx,
		Values: values,
	}
	return r.client.XAdd(ctx, args).Result()
}
// XReadGroup reads up to count messages for a consumer-group member.
// A non-positive block duration makes the call non-blocking. When the
// BLOCK window elapses with no new messages (go-redis reports this as
// goredis.Nil) the call returns (nil, nil) rather than an error.
func (r *RedisMessageBus) XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) {
	result := r.client.XReadGroup(ctx, &goredis.XReadGroupArgs{
		Group:    group,
		Consumer: consumer,
		Streams:  streams,
		Count:    count,
		Block:    block,
	})
	if err := result.Err(); err != nil {
		// Fix: use errors.Is instead of ==, so wrapped goredis.Nil errors
		// are also recognized — consistent with receiveOnce in this file.
		if errors.Is(err, goredis.Nil) {
			return nil, nil
		}
		return nil, err
	}
	// Flatten go-redis XStream results into our StreamMessage format.
	var messages []StreamMessage
	for _, stream := range result.Val() {
		for _, msg := range stream.Messages {
			messages = append(messages, StreamMessage{
				ID:     msg.ID,
				Values: msg.Values,
			})
		}
	}
	return messages, nil
}
// XAck acknowledges the given entry IDs for a consumer group and
// reports how many were actually acknowledged.
func (r *RedisMessageBus) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) {
	return r.client.XAck(ctx, stream, group, ids...).Result()
}
// XGroupCreate registers a consumer group on an existing stream,
// starting consumption at the given stream ID.
func (r *RedisMessageBus) XGroupCreate(ctx context.Context, stream, group, start string) error {
	cmd := r.client.XGroupCreate(ctx, stream, group, start)
	return cmd.Err()
}
// XGroupCreateMkStream is like XGroupCreate but also creates the stream
// itself if it does not exist yet (MKSTREAM option).
func (r *RedisMessageBus) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error {
	cmd := r.client.XGroupCreateMkStream(ctx, stream, group, start)
	return cmd.Err()
}
// XPending summarizes a group's pending (delivered but unacknowledged)
// entries: total count, lowest/highest pending entry IDs, and a
// per-consumer pending count.
func (r *RedisMessageBus) XPending(ctx context.Context, stream, group string) (*PendingInfo, error) {
	pending, err := r.client.XPending(ctx, stream, group).Result()
	if err != nil {
		return nil, err
	}
	// Copy the per-consumer counts so callers never alias go-redis internals.
	consumers := make(map[string]int64, len(pending.Consumers))
	for name, count := range pending.Consumers {
		consumers[name] = count
	}
	info := &PendingInfo{
		Count:     pending.Count,
		Lower:     pending.Lower,
		Upper:     pending.Higher, // go-redis names the upper bound "Higher"
		Consumers: consumers,
	}
	return info, nil
}
// XClaim transfers ownership of the given pending entries to consumer,
// provided they have been idle for at least minIdleTime. Returns the
// claimed entries; a nil slice means nothing was claimed.
func (r *RedisMessageBus) XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) {
	claimed, err := r.client.XClaim(ctx, &goredis.XClaimArgs{
		Stream:   stream,
		Group:    group,
		Consumer: consumer,
		MinIdle:  minIdleTime,
		Messages: ids,
	}).Result()
	if err != nil {
		return nil, err
	}
	// Convert go-redis XMessage entries into our StreamMessage format.
	var messages []StreamMessage
	for _, m := range claimed {
		messages = append(messages, StreamMessage{ID: m.ID, Values: m.Values})
	}
	return messages, nil
}
// XAutoClaim scans a group's pending entries starting at start and
// claims up to count of them that have been idle for at least
// minIdleTime (requires Redis >= 6.2). Returns the claimed entries and
// the cursor to pass as start on the next call.
func (r *RedisMessageBus) XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) {
	args := &goredis.XAutoClaimArgs{
		Stream:   stream,
		Group:    group,
		Consumer: consumer,
		MinIdle:  minIdleTime,
		Start:    start,
		Count:    count,
	}
	msgs, nextStart, err := r.client.XAutoClaim(ctx, args).Result()
	if err != nil {
		return nil, "", err
	}
	messages := make([]StreamMessage, 0, len(msgs))
	for _, m := range msgs {
		messages = append(messages, StreamMessage{ID: m.ID, Values: m.Values})
	}
	return messages, nextStart, nil
}
// XRange returns the stream entries whose IDs fall within [start, end]
// (use "-" and "+" for the full range).
func (r *RedisMessageBus) XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) {
	msgs, err := r.client.XRange(ctx, stream, start, end).Result()
	if err != nil {
		return nil, err
	}
	// Convert go-redis XMessage entries into our StreamMessage format.
	var messages []StreamMessage
	for _, m := range msgs {
		messages = append(messages, StreamMessage{ID: m.ID, Values: m.Values})
	}
	return messages, nil
}
// XTrimMinID drops stream entries with IDs below minID (time-based
// retention). It uses approximate trimming ("~") with LIMIT 1000 so a
// single call never blocks Redis on a huge stream; remaining old
// entries are removed by later calls. Returns the number trimmed.
func (r *RedisMessageBus) XTrimMinID(ctx context.Context, stream, minID string) (int64, error) {
	cmd := r.client.Do(ctx, "XTRIM", stream, "MINID", "~", minID, "LIMIT", 1000)
	// Int64 surfaces the command error (if any) alongside the count.
	return cmd.Int64()
}
// ========== Sorted Set (ZSET) Operations ==========

// ZAdd inserts (or re-scores) member in the sorted set at key.
func (r *RedisMessageBus) ZAdd(ctx context.Context, key string, score float64, member string) error {
	entry := goredis.Z{Score: score, Member: member}
	return r.client.ZAdd(ctx, key, entry).Err()
}
// ZRangeByScore lists members whose scores lie within [min, max],
// in ascending score order.
func (r *RedisMessageBus) ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) {
	rangeBy := &goredis.ZRangeBy{
		Min: strconv.FormatFloat(min, 'f', -1, 64),
		Max: strconv.FormatFloat(max, 'f', -1, 64),
	}
	return r.client.ZRangeByScore(ctx, key, rangeBy).Result()
}
// ZRemRangeByScore removes members whose scores lie within [min, max]
// and reports how many were removed.
func (r *RedisMessageBus) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) {
	lo := strconv.FormatFloat(min, 'f', -1, 64)
	hi := strconv.FormatFloat(max, 'f', -1, 64)
	return r.client.ZRemRangeByScore(ctx, key, lo, hi).Result()
}
// Incr atomically increments the integer stored at key (used for
// per-document sequence numbers) and returns the new value.
func (r *RedisMessageBus) Incr(ctx context.Context, key string) (int64, error) {
	return r.client.Incr(ctx, key).Result()
}
// AcquireLock attempts to take a distributed lock via SET NX with a TTL.
// The lock value is this server's ID so that RefreshLock/ReleaseLock can
// verify ownership. Returns true when the lock was acquired.
func (r *RedisMessageBus) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
	return r.client.SetNX(ctx, key, r.serverID, ttl).Result()
}
// RefreshLock extends the TTL of a lock, but only when this server still
// owns it. The previous SET-with-XX implementation would extend (and
// overwrite the value of) a lock held by ANOTHER server once our own
// lease had expired — a classic distributed-lock race. The Lua script
// makes the owner check and PEXPIRE atomic on the Redis side.
func (r *RedisMessageBus) RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
	const script = `
if redis.call("GET", KEYS[1]) == ARGV[1] then
	return redis.call("PEXPIRE", KEYS[1], ARGV[2])
end
return 0`
	res, err := r.client.Eval(ctx, script, []string{key}, r.serverID, ttl.Milliseconds()).Result()
	if err != nil {
		return false, err
	}
	// PEXPIRE returns 1 on success; anything else means we no longer own
	// the lock (or it vanished).
	n, ok := res.(int64)
	return ok && n == 1, nil
}
// ReleaseLock releases a distributed lock, but only when this server
// still owns it. The previous unconditional DEL could delete a lock that
// a different server had legitimately acquired after our TTL expired.
// The owner check and DEL are executed atomically via Lua, per the
// unlock pattern documented for the Redis SET command.
func (r *RedisMessageBus) ReleaseLock(ctx context.Context, key string) error {
	const script = `
if redis.call("GET", KEYS[1]) == ARGV[1] then
	return redis.call("DEL", KEYS[1])
end
return 0`
	return r.client.Eval(ctx, script, []string{key}, r.serverID).Err()
}

View File

@@ -0,0 +1,15 @@
package models
import (
"time"
"github.com/google/uuid"
)
// StreamCheckpoint tracks the last processed Redis Stream entry per document.
// The persist worker writes one row per document so that, after a crash or
// failover, replay can resume from LastStreamID/LastSeq instead of from the
// beginning of the stream.
type StreamCheckpoint struct {
	DocumentID   uuid.UUID `json:"document_id"`    // document this checkpoint belongs to
	LastStreamID string    `json:"last_stream_id"` // last persisted Redis stream entry ID
	LastSeq      int64     `json:"last_seq"`       // last persisted per-document sequence number
	UpdatedAt    time.Time `json:"updated_at"`     // when the checkpoint was last advanced
}

View File

@@ -53,6 +53,15 @@ type Store interface {
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error) GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error) GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
// Stream checkpoint operations
UpsertStreamCheckpoint(ctx context.Context, documentID uuid.UUID, streamID string, seq int64) error
GetStreamCheckpoint(ctx context.Context, documentID uuid.UUID) (*models.StreamCheckpoint, error)
// Update history (WAL) operations
InsertUpdateHistoryBatch(ctx context.Context, entries []UpdateHistoryEntry) error
ListUpdateHistoryAfterSeq(ctx context.Context, documentID uuid.UUID, afterSeq int64, limit int) ([]UpdateHistoryEntry, error)
DeleteUpdateHistoryUpToSeq(ctx context.Context, documentID uuid.UUID, maxSeq int64) error
Close() error Close() error
} }

View File

@@ -0,0 +1,46 @@
package store
import (
"context"
"fmt"
"github.com/M1ngdaXie/realtime-collab/internal/models"
"github.com/google/uuid"
)
// UpsertStreamCheckpoint creates or updates the stream checkpoint for a
// document. ON CONFLICT makes the call idempotent: replaying the same
// (streamID, seq) pair simply rewrites the row.
//
// NOTE(review): there is no monotonicity guard — a caller passing an
// older seq will move the checkpoint backwards. Confirm callers only
// ever advance it.
func (s *PostgresStore) UpsertStreamCheckpoint(ctx context.Context, documentID uuid.UUID, streamID string, seq int64) error {
	query := `
		INSERT INTO stream_checkpoints (document_id, last_stream_id, last_seq, updated_at)
		VALUES ($1, $2, $3, NOW())
		ON CONFLICT (document_id)
		DO UPDATE SET last_stream_id = EXCLUDED.last_stream_id,
		              last_seq = EXCLUDED.last_seq,
		              updated_at = NOW()
	`
	if _, err := s.db.ExecContext(ctx, query, documentID, streamID, seq); err != nil {
		return fmt.Errorf("failed to upsert stream checkpoint: %w", err)
	}
	return nil
}
// GetStreamCheckpoint retrieves the stream checkpoint for a document.
// When no checkpoint row exists, the underlying sql.ErrNoRows is wrapped
// with %w, so callers can detect "no checkpoint yet" via
// errors.Is(err, sql.ErrNoRows).
func (s *PostgresStore) GetStreamCheckpoint(ctx context.Context, documentID uuid.UUID) (*models.StreamCheckpoint, error) {
	query := `
		SELECT document_id, last_stream_id, last_seq, updated_at
		FROM stream_checkpoints
		WHERE document_id = $1
	`
	var checkpoint models.StreamCheckpoint
	if err := s.db.QueryRowContext(ctx, query, documentID).Scan(
		&checkpoint.DocumentID,
		&checkpoint.LastStreamID,
		&checkpoint.LastSeq,
		&checkpoint.UpdatedAt,
	); err != nil {
		return nil, fmt.Errorf("failed to get stream checkpoint: %w", err)
	}
	return &checkpoint, nil
}

View File

@@ -71,10 +71,14 @@ func SetupTestDB(t *testing.T) (*PostgresStore, func()) {
// Run migrations // Run migrations
scriptsDir := filepath.Join("..", "..", "scripts") scriptsDir := filepath.Join("..", "..", "scripts")
migrations := []string{ migrations := []string{
"init.sql", "000_extensions.sql",
"001_add_users_and_sessions.sql", "001_init_schema.sql",
"002_add_document_shares.sql", "002_add_users_and_sessions.sql",
"003_add_public_sharing.sql", "003_add_document_shares.sql",
"004_add_public_sharing.sql",
"005_add_share_link_permission.sql",
"010_add_stream_checkpoints.sql",
"011_add_update_history.sql",
} }
for _, migration := range migrations { for _, migration := range migrations {
@@ -107,6 +111,8 @@ func SetupTestDB(t *testing.T) (*PostgresStore, func()) {
func TruncateAllTables(ctx context.Context, store *PostgresStore) error { func TruncateAllTables(ctx context.Context, store *PostgresStore) error {
tables := []string{ tables := []string{
"document_updates", "document_updates",
"document_update_history",
"stream_checkpoints",
"document_shares", "document_shares",
"sessions", "sessions",
"documents", "documents",

View File

@@ -0,0 +1,115 @@
package store
import (
"context"
"fmt"
"strings"
"time"
"unicode/utf8"
"github.com/google/uuid"
)
// UpdateHistoryEntry represents a persisted update from Redis Streams,
// used for recovery and replay (a write-ahead log of document updates).
type UpdateHistoryEntry struct {
	DocumentID uuid.UUID // document the update belongs to
	StreamID   string    // originating Redis stream entry ID (unique per document)
	Seq        int64     // per-document sequence number from the stream entry
	Payload    []byte    // raw update bytes (the worker base64-decodes before insert)
	MsgType    string    // optional message-type tag; empty is stored as NULL
	ServerID   string    // optional ID of the server that produced the update
	CreatedAt  time.Time // insertion timestamp set by the worker
}
// InsertUpdateHistoryBatch inserts update history entries with a single
// multi-row INSERT. ON CONFLICT (document_id, stream_id) DO NOTHING makes
// replays of the same stream entries idempotent.
//
// Each entry contributes 7 bind parameters; Postgres caps a statement at
// 65535 parameters, so batches above ~9300 entries would fail. Current
// callers batch at most a few hundred entries per call.
//
// NOTE(review): the schema also declares a UNIQUE index on
// (document_id, seq). A duplicate seq arriving under a DIFFERENT
// stream_id is NOT absorbed by this ON CONFLICT clause and would fail
// the whole batch — confirm seq values are unique per document upstream.
func (s *PostgresStore) InsertUpdateHistoryBatch(ctx context.Context, entries []UpdateHistoryEntry) error {
	if len(entries) == 0 {
		return nil
	}
	var sb strings.Builder
	sb.WriteString("INSERT INTO document_update_history (document_id, stream_id, seq, payload, msg_type, server_id, created_at) VALUES ")
	args := make([]interface{}, 0, len(entries)*7)
	for i, e := range entries {
		if i > 0 {
			sb.WriteString(",")
		}
		// Placeholders are 1-based: row i uses $((i*7)+1) .. $((i*7)+7).
		base := i*7 + 1
		sb.WriteString(fmt.Sprintf("($%d,$%d,$%d,$%d,$%d,$%d,$%d)", base, base+1, base+2, base+3, base+4, base+5, base+6))
		// Strip NUL bytes / invalid UTF-8 (Postgres TEXT rejects them);
		// empty strings are stored as NULL via nullIfEmpty.
		msgType := sanitizeTextForDB(e.MsgType)
		serverID := sanitizeTextForDB(e.ServerID)
		args = append(args, e.DocumentID, e.StreamID, e.Seq, e.Payload, nullIfEmpty(msgType), nullIfEmpty(serverID), e.CreatedAt)
	}
	// Idempotent insert
	sb.WriteString(" ON CONFLICT (document_id, stream_id) DO NOTHING")
	if _, err := s.db.ExecContext(ctx, sb.String(), args...); err != nil {
		return fmt.Errorf("failed to insert update history batch: %w", err)
	}
	return nil
}
// ListUpdateHistoryAfterSeq returns updates with seq strictly greater
// than afterSeq, ordered ascending, capped at limit (defaults to 1000
// when limit <= 0). Used to replay updates on top of a checkpoint.
func (s *PostgresStore) ListUpdateHistoryAfterSeq(ctx context.Context, documentID uuid.UUID, afterSeq int64, limit int) ([]UpdateHistoryEntry, error) {
	if limit <= 0 {
		limit = 1000
	}
	query := `
		SELECT document_id, stream_id, seq, payload, COALESCE(msg_type, ''), COALESCE(server_id, ''), created_at
		FROM document_update_history
		WHERE document_id = $1 AND seq > $2
		ORDER BY seq ASC
		LIMIT $3
	`
	rows, err := s.db.QueryContext(ctx, query, documentID, afterSeq, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to list update history: %w", err)
	}
	defer rows.Close()
	var results []UpdateHistoryEntry
	for rows.Next() {
		var e UpdateHistoryEntry
		if err := rows.Scan(&e.DocumentID, &e.StreamID, &e.Seq, &e.Payload, &e.MsgType, &e.ServerID, &e.CreatedAt); err != nil {
			return nil, fmt.Errorf("failed to scan update history: %w", err)
		}
		results = append(results, e)
	}
	// Bug fix: rows.Err() surfaces errors that abort iteration (e.g. a
	// dropped connection mid-stream). Without this check a truncated
	// result set was silently returned as if it were complete — dangerous
	// for a replay path.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate update history: %w", err)
	}
	return results, nil
}
// DeleteUpdateHistoryUpToSeq deletes a document's history rows with
// seq <= maxSeq. Called after a snapshot advances the checkpoint so the
// history table does not grow without bound; callers keep a safety lag
// of sequence numbers behind the snapshot before pruning.
func (s *PostgresStore) DeleteUpdateHistoryUpToSeq(ctx context.Context, documentID uuid.UUID, maxSeq int64) error {
	query := `
		DELETE FROM document_update_history
		WHERE document_id = $1 AND seq <= $2
	`
	if _, err := s.db.ExecContext(ctx, query, documentID, maxSeq); err != nil {
		return fmt.Errorf("failed to delete update history: %w", err)
	}
	return nil
}
// nullIfEmpty maps the empty string to SQL NULL (via a nil interface
// value) so optional text columns store NULL rather than ''.
func nullIfEmpty(s string) interface{} {
	if len(s) == 0 {
		return nil
	}
	return s
}
// sanitizeTextForDB rejects strings Postgres cannot store in a TEXT
// column: anything containing a NUL byte or invalid UTF-8 is replaced
// with the empty string (which the caller then maps to NULL).
func sanitizeTextForDB(s string) string {
	switch {
	case s == "":
		return ""
	case strings.ContainsRune(s, 0):
		return ""
	case !utf8.ValidString(s):
		return ""
	default:
		return s
	}
}

View File

@@ -0,0 +1,320 @@
package workers
import (
"context"
"encoding/base64"
"fmt"
"runtime/debug"
"strconv"
"strings"
"time"
"unicode/utf8"
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
"github.com/M1ngdaXie/realtime-collab/internal/store"
"github.com/google/uuid"
"go.uber.org/zap"
)
const (
	// Consumer-group name and distributed-lock identity for the singleton
	// persist worker.
	updatePersistGroupName = "update-persist-worker"
	updatePersistLockKey   = "lock:update-persist-worker"
	updatePersistLockTTL   = 30 * time.Second
	// How often the lock holder polls active streams for new entries (and
	// how long losers wait before re-contending for the lock).
	updatePersistTick = 2 * time.Second
	// Maximum entries fetched per XREADGROUP / XAUTOCLAIM call.
	updateReadCount = 200
	updateReadBlock = -1 // negative → go-redis omits BLOCK clause → non-blocking
	// Capacity hint for the per-document history buffer built each tick.
	updateBatchSize = 500
	// Retention safety margin: keep this many sequence numbers of history
	// behind a snapshot when pruning the WAL.
	updateSafeSeqLag = int64(1000)
	// Pending entries idle longer than this are reclaimed from consumers
	// that presumably crashed.
	updateAutoClaimIdle = 30 * time.Second
	// Interval between "still alive" log lines from the lock holder.
	updateHeartbeatEvery = 30 * time.Second
)
// StartUpdatePersistWorker runs the loop that drains Redis Stream updates
// into Postgres (the durable update history). It blocks until ctx is done,
// so it should be launched as a goroutine.
//
// Only one instance across the cluster does real work at a time: each
// iteration competes for a Redis lock; losers sleep one tick and retry,
// while the winner runs runUpdatePersistWorker until the lock is lost or
// ctx is cancelled. Panics inside an iteration are recovered and logged so
// a single bad batch cannot kill the loop.
func StartUpdatePersistWorker(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) {
	// Nothing useful can be done without both a bus and a store.
	if msgBus == nil || dbStore == nil {
		return
	}
	for {
		// Each iteration runs in a closure so the deferred recover() traps
		// panics per-iteration; "return" below exits the closure, not the loop.
		func() {
			defer func() {
				if r := recover(); r != nil {
					logWorker(logger, "Update persist worker panic",
						zap.Any("panic", r),
						zap.ByteString("stack", debug.Stack()))
				}
			}()
			select {
			case <-ctx.Done():
				return
			default:
			}
			acquired, err := msgBus.AcquireLock(ctx, updatePersistLockKey, updatePersistLockTTL)
			if err != nil {
				logWorker(logger, "Failed to acquire update persist worker lock", zap.Error(err))
				time.Sleep(updatePersistTick)
				return
			}
			if !acquired {
				// Another instance holds the lock; retry after a tick.
				time.Sleep(updatePersistTick)
				return
			}
			logWorker(logger, "Update persist worker lock acquired", zap.String("server_id", serverID))
			runUpdatePersistWorker(ctx, msgBus, dbStore, logger, serverID)
		}()
		select {
		case <-ctx.Done():
			return
		default:
		}
		// If the worker exited (including panic), pause briefly before retry.
		time.Sleep(updatePersistTick)
	}
}
// runUpdatePersistWorker is the lock holder's main loop. It multiplexes:
//   - lock refresh at half the TTL (failing to refresh terminates the loop
//     so another instance can take over),
//   - a periodic heartbeat log line,
//   - the processing tick that actually drains streams into Postgres.
//
// The lock is released (best-effort) on every exit path.
func runUpdatePersistWorker(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) {
	ticker := time.NewTicker(updatePersistTick)
	defer ticker.Stop()
	refreshTicker := time.NewTicker(updatePersistLockTTL / 2)
	defer refreshTicker.Stop()
	heartbeatTicker := time.NewTicker(updateHeartbeatEvery)
	defer heartbeatTicker.Stop()
	for {
		select {
		case <-ctx.Done():
			_ = msgBus.ReleaseLock(ctx, updatePersistLockKey)
			return
		case <-refreshTicker.C:
			ok, err := msgBus.RefreshLock(ctx, updatePersistLockKey, updatePersistLockTTL)
			if err != nil || !ok {
				// Refresh failed or we no longer own the key: stand down so a
				// healthy instance can take over.
				logWorker(logger, "Update persist worker lock lost", zap.Error(err))
				_ = msgBus.ReleaseLock(ctx, updatePersistLockKey)
				return
			}
		case <-heartbeatTicker.C:
			logWorker(logger, "Update persist worker heartbeat", zap.String("server_id", serverID))
		case <-ticker.C:
			// Errors are logged, not fatal: the next tick retries from the
			// consumer group's pending state.
			if err := processUpdatePersistence(ctx, msgBus, dbStore, logger, serverID); err != nil {
				logWorker(logger, "Update persist worker tick failed", zap.Error(err))
			}
		}
	}
}
// processUpdatePersistence performs one drain pass:
//
//  1. Fetch recently-active documents from the "active-streams" sorted set
//     (scores are unix timestamps; only the last 60s are considered —
//     presumably publishers ZADD on every stream write; TODO confirm).
//  2. Best-effort prune of active-streams members older than 5 minutes.
//  3. Per document: ensure the consumer group exists, reclaim idle pending
//     entries (e.g. from a crashed predecessor), then read new entries
//     with a non-blocking XREADGROUP.
//  4. Batch-insert the collected history rows, then ACK. On insert failure
//     the ACK is skipped, so the same entries are redelivered next tick.
func processUpdatePersistence(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) error {
	// Only process documents with recent stream activity (active in the last 60 seconds)
	cutoff := float64(time.Now().Add(-60 * time.Second).Unix())
	activeDocIDs, err := msgBus.ZRangeByScore(ctx, "active-streams", cutoff, float64(time.Now().Unix()))
	if err != nil {
		return fmt.Errorf("failed to get active streams: %w", err)
	}
	// Prune stale entries older than 5 minutes (best-effort cleanup)
	stale := float64(time.Now().Add(-5 * time.Minute).Unix())
	if _, err := msgBus.ZRemRangeByScore(ctx, "active-streams", 0, stale); err != nil {
		logWorker(logger, "Failed to prune stale active-streams entries", zap.Error(err))
	}
	for _, docIDStr := range activeDocIDs {
		docID, err := uuid.Parse(docIDStr)
		if err != nil {
			logWorker(logger, "Invalid document ID in active-streams", zap.String("doc_id", docIDStr))
			continue
		}
		streamKey := "stream:" + docIDStr
		if err := ensureConsumerGroup(ctx, msgBus, streamKey, updatePersistGroupName); err != nil {
			logWorker(logger, "Failed to ensure update persist consumer group", zap.String("stream", streamKey), zap.Error(err))
			continue
		}
		var ackIDs []string
		docEntries := make([]store.UpdateHistoryEntry, 0, updateBatchSize)
		// First, try to claim idle pending messages (e.g., from previous crashes)
		claimed, _, err := msgBus.XAutoClaim(ctx, streamKey, updatePersistGroupName, serverID, updateAutoClaimIdle, "0-0", updateReadCount)
		if err != nil {
			logWorker(logger, "XAutoClaim failed", zap.String("stream", streamKey), zap.Error(err))
		} else if len(claimed) > 0 {
			collectStreamMessages(ctx, msgBus, dbStore, logger, docID, streamKey, claimed, &docEntries, &ackIDs)
		}
		messages, err := msgBus.XReadGroup(ctx, updatePersistGroupName, serverID, []string{streamKey, ">"}, updateReadCount, updateReadBlock)
		if err != nil {
			logWorker(logger, "XReadGroup failed", zap.String("stream", streamKey), zap.Error(err))
			continue
		}
		if len(messages) > 0 {
			collectStreamMessages(ctx, msgBus, dbStore, logger, docID, streamKey, messages, &docEntries, &ackIDs)
		}
		if len(docEntries) > 0 {
			if err := dbStore.InsertUpdateHistoryBatch(ctx, docEntries); err != nil {
				logWorker(logger, "Failed to insert update history batch", zap.Error(err))
				// Skip ACK to retry on next tick
				continue
			}
		}
		if len(ackIDs) > 0 {
			if _, err := msgBus.XAck(ctx, streamKey, updatePersistGroupName, ackIDs...); err != nil {
				logWorker(logger, "XAck failed", zap.String("stream", streamKey), zap.Error(err))
			}
		}
	}
	return nil
}
// collectStreamMessages converts raw stream entries into history rows and
// checkpoint/retention actions:
//
//   - "update" entries carry a base64 payload; they are decoded and
//     buffered into docEntries for a later batch insert.
//   - "snapshot" entries advance the stream checkpoint and trigger
//     best-effort retention: DB history older than seq-updateSafeSeqLag is
//     pruned and the Redis stream is trimmed up to the snapshot entry.
//
// Every handled entry ID is appended to ackIDs, including unknown types
// (so they do not pile up as pending). The one exception is an update
// whose payload fails to decode: it is logged and left UNACKED so it is
// retried/reclaimed later.
func collectStreamMessages(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, documentID uuid.UUID, streamKey string, messages []messagebus.StreamMessage, docEntries *[]store.UpdateHistoryEntry, ackIDs *[]string) {
	for _, msg := range messages {
		switch getString(msg.Values["type"]) {
		case "update":
			payloadB64 := getString(msg.Values["yjs_payload"])
			payload, err := base64.StdEncoding.DecodeString(payloadB64)
			if err != nil {
				logWorker(logger, "Failed to decode update payload",
					zap.String("stream", streamKey),
					zap.String("stream_id", msg.ID),
					zap.Error(err))
				continue // leave unacked so it is retried / reclaimed
			}
			// Fix: this local was previously also named msgType, shadowing
			// the outer "type" value — renamed for clarity (behavior unchanged).
			entryMsgType := normalizeMsgType(msg.Values["msg_type"])
			originServerID := sanitizeText(getString(msg.Values["server_id"]))
			*docEntries = append(*docEntries, store.UpdateHistoryEntry{
				DocumentID: documentID,
				StreamID:   msg.ID,
				Seq:        parseInt64(msg.Values["seq"]),
				Payload:    payload,
				MsgType:    entryMsgType,
				ServerID:   originServerID,
				CreatedAt:  time.Now().UTC(),
			})
		case "snapshot":
			seq := parseInt64(msg.Values["seq"])
			if seq > 0 {
				if err := dbStore.UpsertStreamCheckpoint(ctx, documentID, msg.ID, seq); err != nil {
					logWorker(logger, "Failed to upsert stream checkpoint from snapshot marker",
						zap.String("document_id", documentID.String()),
						zap.Error(err))
				}
				// Retention: prune DB history, keeping updateSafeSeqLag
				// sequence numbers behind the snapshot (best-effort).
				if maxSeq := seq - updateSafeSeqLag; maxSeq > 0 {
					if err := dbStore.DeleteUpdateHistoryUpToSeq(ctx, documentID, maxSeq); err != nil {
						logWorker(logger, "Failed to prune update history",
							zap.String("document_id", documentID.String()),
							zap.Error(err))
					}
				}
				// Trim the Redis stream so it does not grow unbounded (best-effort).
				if _, err := msgBus.XTrimMinID(ctx, streamKey, msg.ID); err != nil {
					logWorker(logger, "Failed to trim Redis stream",
						zap.String("stream", streamKey),
						zap.Error(err))
				}
			}
		}
		*ackIDs = append(*ackIDs, msg.ID)
	}
}
// ensureConsumerGroup idempotently creates the persist consumer group,
// creating the stream too if needed. A BUSYGROUP reply simply means the
// group already exists and is not treated as an error.
func ensureConsumerGroup(ctx context.Context, msgBus messagebus.MessageBus, streamKey, group string) error {
	err := msgBus.XGroupCreateMkStream(ctx, streamKey, group, "0-0")
	if err == nil || isBusyGroup(err) {
		return nil
	}
	return err
}
func isBusyGroup(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "BUSYGROUP")
}
// getString coerces a Redis stream field value to a string. go-redis
// normally delivers values as strings; []byte and other types are
// handled defensively (fmt.Sprint renders nil as "<nil>").
func getString(value interface{}) string {
	if s, ok := value.(string); ok {
		return s
	}
	if b, ok := value.([]byte); ok {
		return string(b)
	}
	return fmt.Sprint(value)
}
// parseInt64 coerces a Redis stream field value to int64, returning 0
// for anything absent or unparseable (callers treat 0 as "no sequence").
func parseInt64(value interface{}) int64 {
	switch v := value.(type) {
	case int64:
		return v
	case int:
		return int64(v)
	case uint64:
		return int64(v)
	case string:
		return parseInt64String(v)
	case []byte:
		return parseInt64String(string(v))
	default:
		return 0
	}
}

// parseInt64String is the base-10 string half of parseInt64; parse
// failures map to 0.
func parseInt64String(s string) int64 {
	n, err := strconv.ParseInt(s, 10, 64)
	if err != nil {
		return 0
	}
	return n
}
// sanitizeText drops strings that cannot be stored safely downstream:
// anything containing a NUL byte or invalid UTF-8 becomes "".
func sanitizeText(s string) string {
	if s == "" {
		return s
	}
	clean := utf8.ValidString(s) && !strings.Contains(s, "\x00")
	if !clean {
		return ""
	}
	return s
}
// normalizeMsgType renders the stream's msg_type field as text. A
// single-character value is rendered as the decimal value of its byte
// (NOTE(review): presumably these are raw protocol message-type bytes —
// confirm against the publisher). Longer values pass through
// sanitizeText; any other type is stringified first.
func normalizeMsgType(value interface{}) string {
	fromString := func(s string) string {
		switch len(s) {
		case 0:
			return ""
		case 1:
			return strconv.Itoa(int(s[0]))
		default:
			return sanitizeText(s)
		}
	}
	switch v := value.(type) {
	case string:
		return fromString(v)
	case []byte:
		return fromString(string(v))
	default:
		return sanitizeText(fmt.Sprint(v))
	}
}
// logWorker logs at Info level while tolerating a nil logger, so worker
// call sites never need a nil guard.
func logWorker(logger *zap.Logger, msg string, fields ...zap.Field) {
	if logger != nil {
		logger.Info(msg, fields...)
	}
}

View File

@@ -0,0 +1,12 @@
-- Migration: Add stream checkpoints table for Redis Streams durability.
-- One row per document, recording the last stream entry / sequence number
-- the persist worker has durably processed, so replay after a crash can
-- resume from that point instead of the start of the stream.
CREATE TABLE IF NOT EXISTS stream_checkpoints (
    document_id UUID PRIMARY KEY REFERENCES documents(id) ON DELETE CASCADE,
    last_stream_id TEXT NOT NULL,        -- Redis stream entry ID (e.g. "1712345678-0")
    last_seq BIGINT NOT NULL DEFAULT 0,  -- per-document sequence number
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- NOTE(review): no query in this change reads by updated_at — presumably
-- intended for monitoring/cleanup scans of stale checkpoints; confirm.
CREATE INDEX IF NOT EXISTS idx_stream_checkpoints_updated_at
    ON stream_checkpoints(updated_at DESC);

View File

@@ -0,0 +1,22 @@
-- Migration: Add update history table for Redis Stream WAL
-- This table stores per-update payloads for recovery and replay
CREATE TABLE IF NOT EXISTS document_update_history (
id BIGSERIAL PRIMARY KEY,
document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
stream_id TEXT NOT NULL,
seq BIGINT NOT NULL,
payload BYTEA NOT NULL,
msg_type TEXT,
server_id TEXT,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_stream_id
ON document_update_history(document_id, stream_id);
CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_seq
ON document_update_history(document_id, seq);
CREATE INDEX IF NOT EXISTS idx_update_history_document_seq
ON document_update_history(document_id, seq);

View File

@@ -24,8 +24,11 @@ services:
redis: redis:
image: redis:7-alpine image: redis:7-alpine
container_name: realtime-collab-redis container_name: realtime-collab-redis
command: ["redis-server", "--appendonly", "yes"]
ports: ports:
- "6379:6379" - "6379:6379"
volumes:
- redis_data:/data
healthcheck: healthcheck:
test: ["CMD", "redis-cli", "ping"] test: ["CMD", "redis-cli", "ping"]
interval: 10s interval: 10s
@@ -34,3 +37,4 @@ services:
volumes: volumes:
postgres_data: postgres_data:
redis_data:

View File

@@ -53,8 +53,11 @@ export const documentsApi = {
}, },
// Get document Yjs state // Get document Yjs state
getState: async (id: string): Promise<Uint8Array> => { getState: async (id: string, shareToken?: string): Promise<Uint8Array> => {
const response = await authFetch(`${API_BASE_URL}/documents/${id}/state`); const url = shareToken
? `${API_BASE_URL}/documents/${id}/state?share=${shareToken}`
: `${API_BASE_URL}/documents/${id}/state`;
const response = await authFetch(url);
if (!response.ok) throw new Error("Failed to fetch document state"); if (!response.ok) throw new Error("Failed to fetch document state");
const arrayBuffer = await response.arrayBuffer(); const arrayBuffer = await response.arrayBuffer();
return new Uint8Array(arrayBuffer); return new Uint8Array(arrayBuffer);
@@ -167,4 +170,4 @@ export const versionsApi = {
if (!response.ok) throw new Error('Failed to restore version'); if (!response.ok) throw new Error('Failed to restore version');
return response.json(); return response.json();
}, },
}; };

View File

@@ -157,10 +157,34 @@ export const useYjsDocument = (documentId: string, shareToken?: string) => {
setSynced(true); setSynced(true);
}); });
// Connection stability monitoring with reconnection limits
let reconnectCount = 0;
const maxReconnects = 10;
yjsProviders.websocketProvider.on( yjsProviders.websocketProvider.on(
"status", "status",
(event: { status: string }) => { (event: { status: string }) => {
console.log("WebSocket status:", event.status); console.log("WebSocket status:", event.status);
if (event.status === "disconnected") {
reconnectCount++;
if (reconnectCount >= maxReconnects) {
console.error(
"Max reconnection attempts reached. Please refresh the page."
);
// Could optionally show a user notification here
} else {
console.log(
`Reconnection attempt ${reconnectCount}/${maxReconnects}`
);
}
} else if (event.status === "connected") {
// Reset counter on successful connection
if (reconnectCount > 0) {
console.log("Reconnected successfully, resetting counter");
}
reconnectCount = 0;
}
} }
); );

View File

@@ -30,7 +30,7 @@ export const createYjsDocument = async (
// Load initial state from database BEFORE connecting providers // Load initial state from database BEFORE connecting providers
try { try {
const state = await documentsApi.getState(documentId); const state = await documentsApi.getState(documentId, shareToken);
if (state && state.length > 0) { if (state && state.length > 0) {
Y.applyUpdate(ydoc, state); Y.applyUpdate(ydoc, state);
console.log('✓ Loaded document state from database'); console.log('✓ Loaded document state from database');
@@ -51,7 +51,10 @@ export const createYjsDocument = async (
wsUrl, wsUrl,
documentId, documentId,
ydoc, ydoc,
{ params: wsParams } {
params: wsParams,
maxBackoffTime: 10000, // Max 10s between reconnect attempts
}
); );
// Awareness for cursors and presence // Awareness for cursors and presence