feat: implement Redis Streams support with stream checkpoints and update history
- Added Redis Streams operations to the message bus interface and implementation. - Introduced StreamCheckpoint model to track last processed stream entry per document. - Implemented UpsertStreamCheckpoint and GetStreamCheckpoint methods in the Postgres store. - Created document_update_history table for storing update payloads for recovery and replay. - Developed update persist worker to handle Redis Stream updates and persist them to Postgres. - Enhanced Docker Compose configuration for Redis with persistence. - Updated frontend API to support fetching document state with optional share token. - Added connection stability monitoring in the Yjs document hook.
This commit is contained in:
7
.gitignore
vendored
7
.gitignore
vendored
@@ -34,4 +34,9 @@ build/
|
|||||||
# Docker volumes and data
|
# Docker volumes and data
|
||||||
postgres_data/
|
postgres_data/
|
||||||
|
|
||||||
.claude/
|
.claude/
|
||||||
|
|
||||||
|
#test folder profiles
|
||||||
|
loadtest/pprof
|
||||||
|
|
||||||
|
/docs
|
||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
type contextKey string
|
type contextKey string
|
||||||
@@ -20,41 +21,40 @@ const ContextUserIDKey = "user_id"
|
|||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
store store.Store
|
store store.Store
|
||||||
jwtSecret string
|
jwtSecret string
|
||||||
|
logger *zap.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware creates a new auth middleware
|
// NewAuthMiddleware creates a new auth middleware
|
||||||
func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware {
|
func NewAuthMiddleware(store store.Store, jwtSecret string, logger *zap.Logger) *AuthMiddleware {
|
||||||
|
if logger == nil {
|
||||||
|
logger = zap.NewNop()
|
||||||
|
}
|
||||||
return &AuthMiddleware{
|
return &AuthMiddleware{
|
||||||
store: store,
|
store: store,
|
||||||
jwtSecret: jwtSecret,
|
jwtSecret: jwtSecret,
|
||||||
|
logger: logger,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// RequireAuth middleware requires valid authentication
|
// RequireAuth middleware requires valid authentication
|
||||||
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
|
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
fmt.Println("🔒 RequireAuth: Starting authentication check")
|
|
||||||
|
|
||||||
user, claims, err := m.getUserFromToken(c)
|
user, claims, err := m.getUserFromToken(c)
|
||||||
|
|
||||||
fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
|
|
||||||
if claims != nil {
|
|
||||||
fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil || user == nil {
|
if err != nil || user == nil {
|
||||||
fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user)
|
if err != nil {
|
||||||
|
m.logger.Warn("auth failed",
|
||||||
|
zap.Error(err),
|
||||||
|
zap.String("method", c.Request.Method),
|
||||||
|
zap.String("path", c.FullPath()),
|
||||||
|
)
|
||||||
|
}
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: Name and Email might be empty for old JWT tokens
|
// Note: Name and Email might be empty for old JWT tokens
|
||||||
if claims.Name == "" || claims.Email == "" {
|
|
||||||
fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
|
|
||||||
c.Set(ContextUserIDKey, user)
|
c.Set(ContextUserIDKey, user)
|
||||||
c.Set("user_email", claims.Email)
|
c.Set("user_email", claims.Email)
|
||||||
c.Set("user_name", claims.Name)
|
c.Set("user_name", claims.Name)
|
||||||
@@ -88,21 +88,17 @@ func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
|
|||||||
// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error)
|
// 注意:返回值变了,现在返回 (*uuid.UUID, *UserClaims, error)
|
||||||
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
|
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
|
||||||
authHeader := c.GetHeader("Authorization")
|
authHeader := c.GetHeader("Authorization")
|
||||||
fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)
|
|
||||||
|
|
||||||
if authHeader == "" {
|
if authHeader == "" {
|
||||||
fmt.Println("⚠️ getUserFromToken: No Authorization header")
|
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Split(authHeader, " ")
|
parts := strings.Split(authHeader, " ")
|
||||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||||
fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
|
|
||||||
return nil, nil, nil
|
return nil, nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
tokenString := parts[1]
|
tokenString := parts[1]
|
||||||
fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])
|
|
||||||
|
|
||||||
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||||
// 必须要验证签名算法是 HMAC (HS256)
|
// 必须要验证签名算法是 HMAC (HS256)
|
||||||
@@ -113,7 +109,6 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
|
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -123,17 +118,14 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
|
|||||||
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
|
// 因为我们在 GenerateJWT 里存的是 claims.Subject = userID.String()
|
||||||
userID, err := uuid.Parse(claims.Subject)
|
userID, err := uuid.Parse(claims.Subject)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
|
|
||||||
return nil, nil, fmt.Errorf("invalid user ID in token")
|
return nil, nil, fmt.Errorf("invalid user ID in token")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
|
// 成功!直接返回 UUID 和 claims (里面包含 Name 和 Email)
|
||||||
// 这一步完全没有查数据库,速度极快
|
// 这一步完全没有查数据库,速度极快
|
||||||
fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
|
|
||||||
return &userID, claims, nil
|
return &userID, claims, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
|
|
||||||
return nil, nil, fmt.Errorf("invalid token claims")
|
return nil, nil, fmt.Errorf("invalid token claims")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,8 +133,6 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
|
|||||||
func GetUserFromContext(c *gin.Context) *uuid.UUID {
|
func GetUserFromContext(c *gin.Context) *uuid.UUID {
|
||||||
// 修正点:使用和存入时完全一样的 Key
|
// 修正点:使用和存入时完全一样的 Key
|
||||||
val, exists := c.Get(ContextUserIDKey)
|
val, exists := c.Get(ContextUserIDKey)
|
||||||
fmt.Println("within getFromContext the id is ... ")
|
|
||||||
fmt.Println(val)
|
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,22 +1,33 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||||
|
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
|
||||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DocumentHandler struct {
|
type DocumentHandler struct {
|
||||||
store *store.PostgresStore
|
store *store.PostgresStore
|
||||||
|
messageBus messagebus.MessageBus
|
||||||
|
serverID string
|
||||||
|
logger *zap.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
|
func NewDocumentHandler(s *store.PostgresStore, msgBus messagebus.MessageBus, serverID string, logger *zap.Logger) *DocumentHandler {
|
||||||
return &DocumentHandler{store: s}
|
return &DocumentHandler{
|
||||||
|
store: s,
|
||||||
|
messageBus: msgBus,
|
||||||
|
serverID: serverID,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateDocument creates a new document (requires auth)
|
// CreateDocument creates a new document (requires auth)
|
||||||
@@ -45,8 +56,6 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
|
|||||||
|
|
||||||
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
|
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
|
||||||
userID := auth.GetUserFromContext(c)
|
userID := auth.GetUserFromContext(c)
|
||||||
fmt.Println("Getting userId, which is : ")
|
|
||||||
fmt.Println(userID)
|
|
||||||
if userID == nil {
|
if userID == nil {
|
||||||
respondUnauthorized(c, "Authentication required to list documents")
|
respondUnauthorized(c, "Authentication required to list documents")
|
||||||
return
|
return
|
||||||
@@ -113,6 +122,13 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
userID := auth.GetUserFromContext(c)
|
userID := auth.GetUserFromContext(c)
|
||||||
|
shareToken := c.Query("share")
|
||||||
|
|
||||||
|
doc, err := h.store.GetDocument(id)
|
||||||
|
if err != nil {
|
||||||
|
respondNotFound(c, "document")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Check permission if authenticated
|
// Check permission if authenticated
|
||||||
if userID != nil {
|
if userID != nil {
|
||||||
@@ -125,12 +141,22 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
|
|||||||
respondForbidden(c, "Access denied")
|
respondForbidden(c, "Access denied")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
} else {
|
||||||
|
// Unauthenticated: require valid share token or public doc
|
||||||
doc, err := h.store.GetDocument(id)
|
if shareToken != "" {
|
||||||
if err != nil {
|
valid, err := h.store.ValidateShareToken(c.Request.Context(), id, shareToken)
|
||||||
respondNotFound(c, "document")
|
if err != nil {
|
||||||
return
|
respondInternalError(c, "Failed to validate share token", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
respondForbidden(c, "Invalid or expired share token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else if !doc.Is_Public {
|
||||||
|
respondForbidden(c, "This document is not public. Please sign in to access.")
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return empty byte slice if state is nil (new document)
|
// Return empty byte slice if state is nil (new document)
|
||||||
@@ -191,6 +217,16 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if streamID, seq, ok := h.addSnapshotMarker(c.Request.Context(), id); ok {
|
||||||
|
if err := h.store.UpsertStreamCheckpoint(c.Request.Context(), id, streamID, seq); err != nil {
|
||||||
|
if h.logger != nil {
|
||||||
|
h.logger.Warn("Failed to upsert stream checkpoint after snapshot",
|
||||||
|
zap.String("document_id", id.String()),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
|
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -234,6 +270,43 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
|
|||||||
c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"})
|
c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *DocumentHandler) addSnapshotMarker(ctx context.Context, documentID uuid.UUID) (string, int64, bool) {
|
||||||
|
if h.messageBus == nil {
|
||||||
|
return "", 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
streamKey := "stream:" + documentID.String()
|
||||||
|
seqKey := "seq:" + documentID.String()
|
||||||
|
|
||||||
|
seq, err := h.messageBus.Incr(ctx, seqKey)
|
||||||
|
if err != nil {
|
||||||
|
if h.logger != nil {
|
||||||
|
h.logger.Warn("Failed to increment snapshot sequence",
|
||||||
|
zap.String("document_id", documentID.String()),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
return "", 0, false
|
||||||
|
}
|
||||||
|
|
||||||
|
values := map[string]interface{}{
|
||||||
|
"type": "snapshot",
|
||||||
|
"server_id": h.serverID,
|
||||||
|
"seq": seq,
|
||||||
|
"timestamp": time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
streamID, err := h.messageBus.XAdd(ctx, streamKey, 10000, true, values)
|
||||||
|
if err != nil {
|
||||||
|
if h.logger != nil {
|
||||||
|
h.logger.Warn("Failed to add snapshot marker to stream",
|
||||||
|
zap.String("stream_key", streamKey),
|
||||||
|
zap.Error(err))
|
||||||
|
}
|
||||||
|
return "", 0, false
|
||||||
|
}
|
||||||
|
return streamID, seq, true
|
||||||
|
}
|
||||||
|
|
||||||
// GetDocumentPermission returns the user's permission level for a document
|
// GetDocumentPermission returns the user's permission level for a document
|
||||||
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
|
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
|
||||||
documentID, err := uuid.Parse(c.Param("id"))
|
documentID, err := uuid.Parse(c.Param("id"))
|
||||||
|
|||||||
@@ -7,10 +7,12 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||||
|
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
|
||||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DocumentHandlerSuite tests document CRUD operations
|
// DocumentHandlerSuite tests document CRUD operations
|
||||||
@@ -23,7 +25,7 @@ type DocumentHandlerSuite struct {
|
|||||||
// SetupTest runs before each test
|
// SetupTest runs before each test
|
||||||
func (s *DocumentHandlerSuite) SetupTest() {
|
func (s *DocumentHandlerSuite) SetupTest() {
|
||||||
s.BaseHandlerSuite.SetupTest()
|
s.BaseHandlerSuite.SetupTest()
|
||||||
s.handler = NewDocumentHandler(s.store)
|
s.handler = NewDocumentHandler(s.store, messagebus.NewLocalMessageBus(), "test-server", zap.NewNop())
|
||||||
s.setupRouter()
|
s.setupRouter()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/suite"
|
"github.com/stretchr/testify/suite"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ShareHandlerSuite tests for share handler endpoints
|
// ShareHandlerSuite tests for share handler endpoints
|
||||||
@@ -24,7 +25,7 @@ func (s *ShareHandlerSuite) SetupTest() {
|
|||||||
s.BaseHandlerSuite.SetupTest()
|
s.BaseHandlerSuite.SetupTest()
|
||||||
|
|
||||||
// Create handler and router
|
// Create handler and router
|
||||||
authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret)
|
authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret, zap.NewNop())
|
||||||
s.handler = NewShareHandler(s.store, s.cfg)
|
s.handler = NewShareHandler(s.store, s.cfg)
|
||||||
s.router = gin.New()
|
s.router = gin.New()
|
||||||
|
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
package handlers
|
package handlers
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||||
"github.com/M1ngdaXie/realtime-collab/internal/config"
|
"github.com/M1ngdaXie/realtime-collab/internal/config"
|
||||||
"github.com/M1ngdaXie/realtime-collab/internal/hub"
|
"github.com/M1ngdaXie/realtime-collab/internal/hub"
|
||||||
|
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
|
||||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -19,16 +24,18 @@ import (
|
|||||||
var connectionSem = make(chan struct{}, 200)
|
var connectionSem = make(chan struct{}, 200)
|
||||||
|
|
||||||
type WebSocketHandler struct {
|
type WebSocketHandler struct {
|
||||||
hub *hub.Hub
|
hub *hub.Hub
|
||||||
store store.Store
|
store store.Store
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
|
msgBus messagebus.MessageBus
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config) *WebSocketHandler {
|
func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config, msgBus messagebus.MessageBus) *WebSocketHandler {
|
||||||
return &WebSocketHandler{
|
return &WebSocketHandler{
|
||||||
hub: h,
|
hub: h,
|
||||||
store: s,
|
store: s,
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
|
msgBus: msgBus,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -170,6 +177,105 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
|||||||
// Start goroutines
|
// Start goroutines
|
||||||
go client.WritePump()
|
go client.WritePump()
|
||||||
go client.ReadPump()
|
go client.ReadPump()
|
||||||
|
go wsh.replayBacklog(client, documentID)
|
||||||
|
|
||||||
log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID)
|
log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const maxReplayUpdates = 5000
|
||||||
|
|
||||||
|
func (wsh *WebSocketHandler) replayBacklog(client *hub.Client, documentID uuid.UUID) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
checkpoint, err := wsh.store.GetStreamCheckpoint(ctx, documentID)
|
||||||
|
if err != nil || checkpoint == nil || checkpoint.LastStreamID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
streamKey := "stream:" + documentID.String()
|
||||||
|
var sent int
|
||||||
|
|
||||||
|
// Primary: Redis stream replay
|
||||||
|
if wsh.msgBus != nil {
|
||||||
|
messages, err := wsh.msgBus.XRange(ctx, streamKey, checkpoint.LastStreamID, "+")
|
||||||
|
if err == nil && len(messages) > 0 {
|
||||||
|
for _, msg := range messages {
|
||||||
|
if msg.ID == checkpoint.LastStreamID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if sent >= maxReplayUpdates {
|
||||||
|
log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
msgType := getString(msg.Values["type"])
|
||||||
|
if msgType != "update" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seq := parseInt64(msg.Values["seq"])
|
||||||
|
if seq <= checkpoint.LastSeq {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
payloadB64 := getString(msg.Values["yjs_payload"])
|
||||||
|
payload, err := base64.StdEncoding.DecodeString(payloadB64)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if client.Enqueue(payload) {
|
||||||
|
sent++
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: DB history replay
|
||||||
|
updates, err := wsh.store.ListUpdateHistoryAfterSeq(ctx, documentID, checkpoint.LastSeq, maxReplayUpdates)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for _, upd := range updates {
|
||||||
|
if sent >= maxReplayUpdates {
|
||||||
|
log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if client.Enqueue(upd.Payload) {
|
||||||
|
sent++
|
||||||
|
} else {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getString(value interface{}) string {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
return v
|
||||||
|
case []byte:
|
||||||
|
return string(v)
|
||||||
|
default:
|
||||||
|
return fmt.Sprint(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseInt64(value interface{}) int64 {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case int64:
|
||||||
|
return v
|
||||||
|
case int:
|
||||||
|
return int64(v)
|
||||||
|
case uint64:
|
||||||
|
return int64(v)
|
||||||
|
case string:
|
||||||
|
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
case []byte:
|
||||||
|
if parsed, err := strconv.ParseInt(string(v), 10, 64); err == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package hub
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -37,10 +39,11 @@ type Client struct {
|
|||||||
idsMu sync.Mutex
|
idsMu sync.Mutex
|
||||||
}
|
}
|
||||||
type Room struct {
|
type Room struct {
|
||||||
ID string
|
ID string
|
||||||
clients map[*Client]bool
|
clients map[*Client]bool
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
|
reconnectCount int // Track Redis reconnection attempts for debugging
|
||||||
}
|
}
|
||||||
|
|
||||||
type Hub struct {
|
type Hub struct {
|
||||||
@@ -64,6 +67,10 @@ type Hub struct {
|
|||||||
|
|
||||||
// Bounded worker pool for Redis SetAwareness
|
// Bounded worker pool for Redis SetAwareness
|
||||||
awarenessQueue chan awarenessItem
|
awarenessQueue chan awarenessItem
|
||||||
|
|
||||||
|
// Stream persistence worker pool (P1: Redis Streams durability)
|
||||||
|
streamQueue chan *Message // buffered queue for XADD operations
|
||||||
|
streamDone chan struct{} // close to signal stream workers to exit
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -79,6 +86,13 @@ const (
|
|||||||
|
|
||||||
// awarenessQueueSize is the buffer size for awareness updates.
|
// awarenessQueueSize is the buffer size for awareness updates.
|
||||||
awarenessQueueSize = 4096
|
awarenessQueueSize = 4096
|
||||||
|
|
||||||
|
// streamWorkerCount is the number of fixed goroutines consuming from streamQueue.
|
||||||
|
// 50 workers match publish workers for consistent throughput.
|
||||||
|
streamWorkerCount = 50
|
||||||
|
|
||||||
|
// streamQueueSize is the buffer size for the stream persistence queue.
|
||||||
|
streamQueueSize = 4096
|
||||||
)
|
)
|
||||||
|
|
||||||
type awarenessItem struct {
|
type awarenessItem struct {
|
||||||
@@ -103,11 +117,15 @@ func NewHub(messagebus messagebus.MessageBus, serverID string, logger *zap.Logge
|
|||||||
publishDone: make(chan struct{}),
|
publishDone: make(chan struct{}),
|
||||||
// bounded awareness worker pool
|
// bounded awareness worker pool
|
||||||
awarenessQueue: make(chan awarenessItem, awarenessQueueSize),
|
awarenessQueue: make(chan awarenessItem, awarenessQueueSize),
|
||||||
|
// Stream persistence worker pool
|
||||||
|
streamQueue: make(chan *Message, streamQueueSize),
|
||||||
|
streamDone: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start the fixed worker pool for Redis publishing
|
// Start the fixed worker pool for Redis publishing
|
||||||
h.startPublishWorkers(publishWorkerCount)
|
h.startPublishWorkers(publishWorkerCount)
|
||||||
h.startAwarenessWorkers(awarenessWorkerCount)
|
h.startAwarenessWorkers(awarenessWorkerCount)
|
||||||
|
h.startStreamWorkers(streamWorkerCount)
|
||||||
|
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
@@ -173,6 +191,82 @@ func (h *Hub) startAwarenessWorkers(n int) {
|
|||||||
h.logger.Info("Awareness worker pool started", zap.Int("workers", n))
|
h.logger.Info("Awareness worker pool started", zap.Int("workers", n))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// startStreamWorkers launches n goroutines that consume from streamQueue
|
||||||
|
// and add messages to Redis Streams for durability and replay.
|
||||||
|
func (h *Hub) startStreamWorkers(n int) {
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
go func(workerID int) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-h.streamDone:
|
||||||
|
h.logger.Info("Stream worker exiting", zap.Int("worker_id", workerID))
|
||||||
|
return
|
||||||
|
case msg, ok := <-h.streamQueue:
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.addToStream(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
h.logger.Info("Stream worker pool started", zap.Int("workers", n))
|
||||||
|
}
|
||||||
|
|
||||||
|
// encodeBase64 encodes binary data to base64 string for Redis storage
|
||||||
|
func encodeBase64(data []byte) string {
|
||||||
|
return base64.StdEncoding.EncodeToString(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addToStream adds a message to Redis Streams for durability
|
||||||
|
func (h *Hub) addToStream(msg *Message) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
streamKey := "stream:" + msg.RoomID
|
||||||
|
|
||||||
|
// Get next sequence number atomically
|
||||||
|
seqKey := "seq:" + msg.RoomID
|
||||||
|
seq, err := h.messagebus.Incr(ctx, seqKey)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("Failed to increment sequence",
|
||||||
|
zap.String("room_id", msg.RoomID),
|
||||||
|
zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode payload as base64 (binary-safe storage)
|
||||||
|
payload := encodeBase64(msg.Data)
|
||||||
|
|
||||||
|
// Extract Yjs message type from first byte as numeric string
|
||||||
|
msgType := "0"
|
||||||
|
if len(msg.Data) > 0 {
|
||||||
|
msgType = strconv.Itoa(int(msg.Data[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add entry to Stream with MAXLEN trimming
|
||||||
|
values := map[string]interface{}{
|
||||||
|
"type": "update",
|
||||||
|
"server_id": h.serverID,
|
||||||
|
"yjs_payload": payload,
|
||||||
|
"msg_type": msgType,
|
||||||
|
"seq": seq,
|
||||||
|
"timestamp": time.Now().Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = h.messagebus.XAdd(ctx, streamKey, 10000, true, values)
|
||||||
|
if err != nil {
|
||||||
|
h.logger.Error("Failed to add to Stream",
|
||||||
|
zap.String("stream_key", streamKey),
|
||||||
|
zap.Int64("seq", seq),
|
||||||
|
zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mark this document as active so the persist worker only processes active streams
|
||||||
|
_ = h.messagebus.ZAdd(ctx, "active-streams", float64(time.Now().Unix()), msg.RoomID)
|
||||||
|
}
|
||||||
|
|
||||||
func (h *Hub) Run() {
|
func (h *Hub) Run() {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -471,6 +565,7 @@ func (h *Hub) broadcastMessage(message *Message) {
|
|||||||
// 只有本地客户端发出的消息 (sender != nil) 才推送到 Redis
|
// 只有本地客户端发出的消息 (sender != nil) 才推送到 Redis
|
||||||
// P0 fix: send to bounded worker pool instead of spawning unbounded goroutines
|
// P0 fix: send to bounded worker pool instead of spawning unbounded goroutines
|
||||||
if message.sender != nil && !h.fallbackMode && h.messagebus != nil {
|
if message.sender != nil && !h.fallbackMode && h.messagebus != nil {
|
||||||
|
// 3a. Publish to Pub/Sub (real-time cross-server broadcast)
|
||||||
select {
|
select {
|
||||||
case h.publishQueue <- message:
|
case h.publishQueue <- message:
|
||||||
// Successfully queued for async publish by worker pool
|
// Successfully queued for async publish by worker pool
|
||||||
@@ -479,6 +574,19 @@ func (h *Hub) broadcastMessage(message *Message) {
|
|||||||
h.logger.Warn("Publish queue full, dropping Redis publish",
|
h.logger.Warn("Publish queue full, dropping Redis publish",
|
||||||
zap.String("room_id", message.RoomID))
|
zap.String("room_id", message.RoomID))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 3b. Add to Stream for durability (only Type 0 updates, not Type 1 awareness)
|
||||||
|
// Type 0 = Yjs sync/update messages (document changes)
|
||||||
|
// Type 1 = Yjs awareness messages (cursors, presence) - ephemeral, skip
|
||||||
|
if len(message.Data) > 0 && message.Data[0] == 0 {
|
||||||
|
select {
|
||||||
|
case h.streamQueue <- message:
|
||||||
|
// Successfully queued for async Stream add
|
||||||
|
default:
|
||||||
|
h.logger.Warn("Stream queue full, dropping durability",
|
||||||
|
zap.String("room_id", message.RoomID))
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -504,10 +612,28 @@ func (h *Hub) broadcastToLocalClients(room *Room, data []byte, sender *Client) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
func (h *Hub) startRoomMessageForwarding(ctx context.Context, roomID string, msgChan <-chan []byte) {
|
func (h *Hub) startRoomMessageForwarding(ctx context.Context, roomID string, msgChan <-chan []byte) {
|
||||||
h.logger.Info("Starting message forwarding from Redis to room",
|
// Increment and log reconnection count for debugging
|
||||||
zap.String("room_id", roomID),
|
h.mu.RLock()
|
||||||
zap.String("server_id", h.serverID),
|
room, exists := h.rooms[roomID]
|
||||||
)
|
h.mu.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
room.mu.Lock()
|
||||||
|
room.reconnectCount++
|
||||||
|
reconnectCount := room.reconnectCount
|
||||||
|
room.mu.Unlock()
|
||||||
|
|
||||||
|
h.logger.Info("Starting message forwarding from Redis to room",
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
zap.String("server_id", h.serverID),
|
||||||
|
zap.Int("reconnect_count", reconnectCount),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
h.logger.Info("Starting message forwarding from Redis to room",
|
||||||
|
zap.String("room_id", roomID),
|
||||||
|
zap.String("server_id", h.serverID),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
@@ -791,12 +917,28 @@ func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string
|
|||||||
UserAvatar: userAvatar,
|
UserAvatar: userAvatar,
|
||||||
Permission: permission,
|
Permission: permission,
|
||||||
Conn: conn,
|
Conn: conn,
|
||||||
send: make(chan []byte, 1024),
|
send: make(chan []byte, 8192),
|
||||||
hub: hub,
|
hub: hub,
|
||||||
roomID: roomID,
|
roomID: roomID,
|
||||||
observedYjsIDs: make(map[uint64]uint64),
|
observedYjsIDs: make(map[uint64]uint64),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Enqueue sends a message to the client send buffer (non-blocking).
|
||||||
|
// Returns false if the buffer is full.
|
||||||
|
func (c *Client) Enqueue(message []byte) bool {
|
||||||
|
select {
|
||||||
|
case c.send <- message:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
if c.hub != nil && c.hub.logger != nil {
|
||||||
|
c.hub.logger.Warn("Client send buffer full during replay",
|
||||||
|
zap.String("client_id", c.ID),
|
||||||
|
zap.String("room_id", c.roomID))
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
func (c *Client) unregister() {
|
func (c *Client) unregister() {
|
||||||
c.unregisterOnce.Do(func() {
|
c.unregisterOnce.Do(func() {
|
||||||
c.hub.Unregister <- c
|
c.hub.Unregister <- c
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func NewLogger(isDevelopment bool) (*zap.Logger, error) {
|
|||||||
// 👇 关键修改:直接拉到 Fatal 级别
|
// 👇 关键修改:直接拉到 Fatal 级别
|
||||||
// 这样 Error, Warn, Info, Debug 全部都会被忽略
|
// 这样 Error, Warn, Info, Debug 全部都会被忽略
|
||||||
// 彻底消除 IO 锁竞争
|
// 彻底消除 IO 锁竞争
|
||||||
config.Level = zap.NewAtomicLevelAt(zapcore.FatalLevel)
|
config.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel)
|
||||||
|
|
||||||
logger, err := config.Build()
|
logger, err := config.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package messagebus
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MessageBus abstracts message distribution across server instances
|
// MessageBus abstracts message distribution across server instances
|
||||||
@@ -33,6 +34,72 @@ type MessageBus interface {
|
|||||||
|
|
||||||
// Close gracefully shuts down the message bus
|
// Close gracefully shuts down the message bus
|
||||||
Close() error
|
Close() error
|
||||||
|
|
||||||
|
// ========== Redis Streams Operations ==========
|
||||||
|
|
||||||
|
// XAdd adds a new entry to a stream with optional MAXLEN trimming
|
||||||
|
XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error)
|
||||||
|
|
||||||
|
// XReadGroup reads messages from a stream using a consumer group
|
||||||
|
XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error)
|
||||||
|
|
||||||
|
// XAck acknowledges one or more messages from a consumer group
|
||||||
|
XAck(ctx context.Context, stream, group string, ids ...string) (int64, error)
|
||||||
|
|
||||||
|
// XGroupCreate creates a new consumer group for a stream
|
||||||
|
XGroupCreate(ctx context.Context, stream, group, start string) error
|
||||||
|
|
||||||
|
// XGroupCreateMkStream creates a consumer group and the stream if it doesn't exist
|
||||||
|
XGroupCreateMkStream(ctx context.Context, stream, group, start string) error
|
||||||
|
|
||||||
|
// XPending returns pending messages information for a consumer group
|
||||||
|
XPending(ctx context.Context, stream, group string) (*PendingInfo, error)
|
||||||
|
|
||||||
|
// XClaim claims pending messages from a consumer group
|
||||||
|
XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error)
|
||||||
|
|
||||||
|
// XAutoClaim claims pending messages automatically (Redis >= 6.2)
|
||||||
|
// Returns claimed messages and next start ID.
|
||||||
|
XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error)
|
||||||
|
|
||||||
|
// XRange reads a range of messages from a stream
|
||||||
|
XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error)
|
||||||
|
|
||||||
|
// XTrimMinID trims a stream to a minimum ID (time-based retention)
|
||||||
|
XTrimMinID(ctx context.Context, stream, minID string) (int64, error)
|
||||||
|
|
||||||
|
// Incr increments a counter atomically (for sequence numbers)
|
||||||
|
Incr(ctx context.Context, key string) (int64, error)
|
||||||
|
|
||||||
|
// ========== Sorted Set (ZSET) Operations ==========
|
||||||
|
|
||||||
|
// ZAdd adds a member with a score to a sorted set (used for active-stream tracking)
|
||||||
|
ZAdd(ctx context.Context, key string, score float64, member string) error
|
||||||
|
|
||||||
|
// ZRangeByScore returns members with scores between min and max
|
||||||
|
ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error)
|
||||||
|
|
||||||
|
// ZRemRangeByScore removes members with scores between min and max
|
||||||
|
ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error)
|
||||||
|
|
||||||
|
// Distributed lock helpers (used by background workers)
|
||||||
|
AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error)
|
||||||
|
RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error)
|
||||||
|
ReleaseLock(ctx context.Context, key string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamMessage represents a message from a Redis Stream
|
||||||
|
type StreamMessage struct {
|
||||||
|
ID string
|
||||||
|
Values map[string]interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PendingInfo contains information about pending messages in a consumer group
|
||||||
|
type PendingInfo struct {
|
||||||
|
Count int64
|
||||||
|
Lower string
|
||||||
|
Upper string
|
||||||
|
Consumers map[string]int64
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalMessageBus is a no-op implementation for single-server mode
|
// LocalMessageBus is a no-op implementation for single-server mode
|
||||||
@@ -78,3 +145,73 @@ func (l *LocalMessageBus) IsHealthy() bool {
|
|||||||
func (l *LocalMessageBus) Close() error {
|
func (l *LocalMessageBus) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ========== Redis Streams Operations (No-op for local mode) ==========
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) {
|
||||||
|
return "0-0", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) XGroupCreate(ctx context.Context, stream, group, start string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) XPending(ctx context.Context, stream, group string) (*PendingInfo, error) {
|
||||||
|
return &PendingInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) {
|
||||||
|
return nil, "0-0", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) XTrimMinID(ctx context.Context, stream, minID string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) Incr(ctx context.Context, key string) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) ZAdd(ctx context.Context, key string, score float64, member string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalMessageBus) ReleaseLock(ctx context.Context, key string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -88,6 +89,23 @@ func NewRedisMessageBus(redisURL string, serverID string, logger *zap.Logger) (*
|
|||||||
// - Redis will handle stale connections via TCP keepalive
|
// - Redis will handle stale connections via TCP keepalive
|
||||||
opts.ConnMaxLifetime = 1 * time.Hour
|
opts.ConnMaxLifetime = 1 * time.Hour
|
||||||
|
|
||||||
|
// ================================
|
||||||
|
// Socket-Level Timeout Configuration (prevents indefinite hangs)
|
||||||
|
// ================================
|
||||||
|
// Without these, TCP reads/writes block indefinitely when Redis is unresponsive,
|
||||||
|
// causing OS-level timeouts (60-120s) instead of application-level control.
|
||||||
|
|
||||||
|
// DialTimeout: How long to wait for initial connection establishment
|
||||||
|
opts.DialTimeout = 5 * time.Second
|
||||||
|
|
||||||
|
// ReadTimeout: Maximum time for socket read operations
|
||||||
|
// - 30s is appropriate for PubSub (long intervals between messages are normal)
|
||||||
|
// - Prevents indefinite blocking when Redis hangs
|
||||||
|
opts.ReadTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
// WriteTimeout: Maximum time for socket write operations
|
||||||
|
opts.WriteTimeout = 5 * time.Second
|
||||||
|
|
||||||
client := goredis.NewClient(opts)
|
client := goredis.NewClient(opts)
|
||||||
|
|
||||||
// ================================
|
// ================================
|
||||||
@@ -215,12 +233,15 @@ func (r *RedisMessageBus) readLoop(ctx context.Context, roomID string, sub *subs
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
r.logger.Warn("PubSub initial subscription failed, retrying with backoff",
|
||||||
|
zap.String("roomID", roomID),
|
||||||
|
zap.Error(err),
|
||||||
|
zap.Duration("backoff", backoff),
|
||||||
|
)
|
||||||
time.Sleep(backoff)
|
time.Sleep(backoff)
|
||||||
if backoff < maxBackoff {
|
backoff = backoff * 2
|
||||||
backoff *= 2
|
if backoff > maxBackoff {
|
||||||
if backoff > maxBackoff {
|
backoff = maxBackoff
|
||||||
backoff = maxBackoff
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -242,12 +263,15 @@ func (r *RedisMessageBus) readLoop(ctx context.Context, roomID string, sub *subs
|
|||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
r.logger.Warn("PubSub receive failed, retrying with backoff",
|
||||||
|
zap.String("roomID", roomID),
|
||||||
|
zap.Error(err),
|
||||||
|
zap.Duration("backoff", backoff),
|
||||||
|
)
|
||||||
time.Sleep(backoff)
|
time.Sleep(backoff)
|
||||||
if backoff < maxBackoff {
|
backoff = backoff * 2
|
||||||
backoff *= 2
|
if backoff > maxBackoff {
|
||||||
if backoff > maxBackoff {
|
backoff = maxBackoff
|
||||||
backoff = maxBackoff
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -261,12 +285,15 @@ func (r *RedisMessageBus) receiveOnce(ctx context.Context, roomID string, pubsub
|
|||||||
|
|
||||||
msg, err := pubsub.ReceiveTimeout(ctx, 5*time.Second)
|
msg, err := pubsub.ReceiveTimeout(ctx, 5*time.Second)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
|
if ctx.Err() != nil {
|
||||||
return err
|
return ctx.Err()
|
||||||
}
|
}
|
||||||
if errors.Is(err, goredis.Nil) {
|
if errors.Is(err, goredis.Nil) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if isTimeoutErr(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
r.logger.Warn("pubsub receive error, closing subscription",
|
r.logger.Warn("pubsub receive error, closing subscription",
|
||||||
zap.String("roomID", roomID),
|
zap.String("roomID", roomID),
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
@@ -308,6 +335,17 @@ func (r *RedisMessageBus) receiveOnce(ctx context.Context, roomID string, pubsub
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isTimeoutErr(err error) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
var netErr net.Error
|
||||||
|
return errors.As(err, &netErr) && netErr.Timeout()
|
||||||
|
}
|
||||||
|
|
||||||
// Unsubscribe stops listening to a room
|
// Unsubscribe stops listening to a room
|
||||||
func (r *RedisMessageBus) Unsubscribe(ctx context.Context, roomID string) error {
|
func (r *RedisMessageBus) Unsubscribe(ctx context.Context, roomID string) error {
|
||||||
r.subMu.Lock()
|
r.subMu.Lock()
|
||||||
@@ -430,7 +468,7 @@ func (r *RedisMessageBus) DeleteAwareness(ctx context.Context, roomID string, cl
|
|||||||
|
|
||||||
// IsHealthy checks Redis connectivity
|
// IsHealthy checks Redis connectivity
|
||||||
func (r *RedisMessageBus) IsHealthy() bool {
|
func (r *RedisMessageBus) IsHealthy() bool {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
// 只有 Ping 成功且没有报错,才认为服务是健康的
|
// 只有 Ping 成功且没有报错,才认为服务是健康的
|
||||||
@@ -516,3 +554,223 @@ func (r *RedisMessageBus) ClearAllAwareness(ctx context.Context, roomID string)
|
|||||||
// 直接使用 Del 命令删除整个 Key
|
// 直接使用 Del 命令删除整个 Key
|
||||||
return r.client.Del(ctx, key).Err()
|
return r.client.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ========== Redis Streams Operations ==========
|
||||||
|
|
||||||
|
// XAdd adds a new entry to a stream with optional MAXLEN trimming
|
||||||
|
func (r *RedisMessageBus) XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) {
|
||||||
|
result := r.client.XAdd(ctx, &goredis.XAddArgs{
|
||||||
|
Stream: stream,
|
||||||
|
MaxLen: maxLen,
|
||||||
|
Approx: approx,
|
||||||
|
Values: values,
|
||||||
|
})
|
||||||
|
return result.Val(), result.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// XReadGroup reads messages from a stream using a consumer group
|
||||||
|
func (r *RedisMessageBus) XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) {
|
||||||
|
result := r.client.XReadGroup(ctx, &goredis.XReadGroupArgs{
|
||||||
|
Group: group,
|
||||||
|
Consumer: consumer,
|
||||||
|
Streams: streams,
|
||||||
|
Count: count,
|
||||||
|
Block: block,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := result.Err(); err != nil {
|
||||||
|
// Timeout is not an error, just no new messages
|
||||||
|
if err == goredis.Nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert go-redis XStream to our StreamMessage format
|
||||||
|
var messages []StreamMessage
|
||||||
|
for _, stream := range result.Val() {
|
||||||
|
for _, msg := range stream.Messages {
|
||||||
|
messages = append(messages, StreamMessage{
|
||||||
|
ID: msg.ID,
|
||||||
|
Values: msg.Values,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// XAck acknowledges one or more messages from a consumer group
|
||||||
|
func (r *RedisMessageBus) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) {
|
||||||
|
result := r.client.XAck(ctx, stream, group, ids...)
|
||||||
|
return result.Val(), result.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// XGroupCreate creates a new consumer group for a stream
|
||||||
|
func (r *RedisMessageBus) XGroupCreate(ctx context.Context, stream, group, start string) error {
|
||||||
|
return r.client.XGroupCreate(ctx, stream, group, start).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// XGroupCreateMkStream creates a consumer group and the stream if it doesn't exist
|
||||||
|
func (r *RedisMessageBus) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error {
|
||||||
|
return r.client.XGroupCreateMkStream(ctx, stream, group, start).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// XPending returns pending messages information for a consumer group
|
||||||
|
func (r *RedisMessageBus) XPending(ctx context.Context, stream, group string) (*PendingInfo, error) {
|
||||||
|
result := r.client.XPending(ctx, stream, group)
|
||||||
|
if err := result.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
pending := result.Val()
|
||||||
|
consumers := make(map[string]int64)
|
||||||
|
for name, count := range pending.Consumers {
|
||||||
|
consumers[name] = count
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PendingInfo{
|
||||||
|
Count: pending.Count,
|
||||||
|
Lower: pending.Lower,
|
||||||
|
Upper: pending.Higher, // go-redis uses "Higher" instead of "Upper"
|
||||||
|
Consumers: consumers,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// XClaim claims pending messages from a consumer group
|
||||||
|
func (r *RedisMessageBus) XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) {
|
||||||
|
result := r.client.XClaim(ctx, &goredis.XClaimArgs{
|
||||||
|
Stream: stream,
|
||||||
|
Group: group,
|
||||||
|
Consumer: consumer,
|
||||||
|
MinIdle: minIdleTime,
|
||||||
|
Messages: ids,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := result.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert go-redis XMessage to our StreamMessage format
|
||||||
|
var messages []StreamMessage
|
||||||
|
for _, msg := range result.Val() {
|
||||||
|
messages = append(messages, StreamMessage{
|
||||||
|
ID: msg.ID,
|
||||||
|
Values: msg.Values,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// XAutoClaim claims pending messages automatically (Redis >= 6.2)
|
||||||
|
func (r *RedisMessageBus) XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) {
|
||||||
|
result := r.client.XAutoClaim(ctx, &goredis.XAutoClaimArgs{
|
||||||
|
Stream: stream,
|
||||||
|
Group: group,
|
||||||
|
Consumer: consumer,
|
||||||
|
MinIdle: minIdleTime,
|
||||||
|
Start: start,
|
||||||
|
Count: count,
|
||||||
|
})
|
||||||
|
msgs, nextStart, err := result.Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil, "", err
|
||||||
|
}
|
||||||
|
messages := make([]StreamMessage, 0, len(msgs))
|
||||||
|
for _, msg := range msgs {
|
||||||
|
messages = append(messages, StreamMessage{
|
||||||
|
ID: msg.ID,
|
||||||
|
Values: msg.Values,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return messages, nextStart, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// XRange reads a range of messages from a stream
|
||||||
|
func (r *RedisMessageBus) XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) {
|
||||||
|
result := r.client.XRange(ctx, stream, start, end)
|
||||||
|
if err := result.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert go-redis XMessage to our StreamMessage format
|
||||||
|
var messages []StreamMessage
|
||||||
|
for _, msg := range result.Val() {
|
||||||
|
messages = append(messages, StreamMessage{
|
||||||
|
ID: msg.ID,
|
||||||
|
Values: msg.Values,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// XTrimMinID trims a stream to a minimum ID (time-based retention)
|
||||||
|
func (r *RedisMessageBus) XTrimMinID(ctx context.Context, stream, minID string) (int64, error) {
|
||||||
|
// Use XTRIM with MINID and approximation (~) for efficiency
|
||||||
|
// LIMIT clause prevents blocking Redis during large trims
|
||||||
|
result := r.client.Do(ctx, "XTRIM", stream, "MINID", "~", minID, "LIMIT", 1000)
|
||||||
|
if err := result.Err(); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Result is the number of entries removed
|
||||||
|
trimmed, err := result.Int64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return trimmed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ========== Sorted Set (ZSET) Operations ==========
|
||||||
|
|
||||||
|
// ZAdd adds a member with a score to a sorted set
|
||||||
|
func (r *RedisMessageBus) ZAdd(ctx context.Context, key string, score float64, member string) error {
|
||||||
|
return r.client.ZAdd(ctx, key, goredis.Z{Score: score, Member: member}).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRangeByScore returns members with scores between min and max
|
||||||
|
func (r *RedisMessageBus) ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) {
|
||||||
|
return r.client.ZRangeByScore(ctx, key, &goredis.ZRangeBy{
|
||||||
|
Min: strconv.FormatFloat(min, 'f', -1, 64),
|
||||||
|
Max: strconv.FormatFloat(max, 'f', -1, 64),
|
||||||
|
}).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ZRemRangeByScore removes members with scores between min and max
|
||||||
|
func (r *RedisMessageBus) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) {
|
||||||
|
return r.client.ZRemRangeByScore(ctx, key,
|
||||||
|
strconv.FormatFloat(min, 'f', -1, 64),
|
||||||
|
strconv.FormatFloat(max, 'f', -1, 64),
|
||||||
|
).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Incr increments a counter atomically (for sequence numbers)
|
||||||
|
func (r *RedisMessageBus) Incr(ctx context.Context, key string) (int64, error) {
|
||||||
|
result := r.client.Incr(ctx, key)
|
||||||
|
return result.Val(), result.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// AcquireLock attempts to acquire a distributed lock with TTL
|
||||||
|
func (r *RedisMessageBus) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
|
||||||
|
return r.client.SetNX(ctx, key, r.serverID, ttl).Result()
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshLock extends the TTL on an existing lock
|
||||||
|
func (r *RedisMessageBus) RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
|
||||||
|
result := r.client.SetArgs(ctx, key, r.serverID, goredis.SetArgs{
|
||||||
|
Mode: "XX",
|
||||||
|
TTL: ttl,
|
||||||
|
})
|
||||||
|
if err := result.Err(); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return result.Val() == "OK", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReleaseLock releases a distributed lock
|
||||||
|
func (r *RedisMessageBus) ReleaseLock(ctx context.Context, key string) error {
|
||||||
|
return r.client.Del(ctx, key).Err()
|
||||||
|
}
|
||||||
|
|||||||
15
backend/internal/models/stream_checkpoint.go
Normal file
15
backend/internal/models/stream_checkpoint.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package models
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StreamCheckpoint tracks the last processed Redis Stream entry per document
|
||||||
|
type StreamCheckpoint struct {
|
||||||
|
DocumentID uuid.UUID `json:"document_id"`
|
||||||
|
LastStreamID string `json:"last_stream_id"`
|
||||||
|
LastSeq int64 `json:"last_seq"`
|
||||||
|
UpdatedAt time.Time `json:"updated_at"`
|
||||||
|
}
|
||||||
@@ -53,6 +53,15 @@ type Store interface {
|
|||||||
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
|
GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
|
||||||
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
|
GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)
|
||||||
|
|
||||||
|
// Stream checkpoint operations
|
||||||
|
UpsertStreamCheckpoint(ctx context.Context, documentID uuid.UUID, streamID string, seq int64) error
|
||||||
|
GetStreamCheckpoint(ctx context.Context, documentID uuid.UUID) (*models.StreamCheckpoint, error)
|
||||||
|
|
||||||
|
// Update history (WAL) operations
|
||||||
|
InsertUpdateHistoryBatch(ctx context.Context, entries []UpdateHistoryEntry) error
|
||||||
|
ListUpdateHistoryAfterSeq(ctx context.Context, documentID uuid.UUID, afterSeq int64, limit int) ([]UpdateHistoryEntry, error)
|
||||||
|
DeleteUpdateHistoryUpToSeq(ctx context.Context, documentID uuid.UUID, maxSeq int64) error
|
||||||
|
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
46
backend/internal/store/stream_checkpoint.go
Normal file
46
backend/internal/store/stream_checkpoint.go
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpsertStreamCheckpoint creates or updates the stream checkpoint for a document
|
||||||
|
func (s *PostgresStore) UpsertStreamCheckpoint(ctx context.Context, documentID uuid.UUID, streamID string, seq int64) error {
|
||||||
|
query := `
|
||||||
|
INSERT INTO stream_checkpoints (document_id, last_stream_id, last_seq, updated_at)
|
||||||
|
VALUES ($1, $2, $3, NOW())
|
||||||
|
ON CONFLICT (document_id)
|
||||||
|
DO UPDATE SET last_stream_id = EXCLUDED.last_stream_id,
|
||||||
|
last_seq = EXCLUDED.last_seq,
|
||||||
|
updated_at = NOW()
|
||||||
|
`
|
||||||
|
|
||||||
|
if _, err := s.db.ExecContext(ctx, query, documentID, streamID, seq); err != nil {
|
||||||
|
return fmt.Errorf("failed to upsert stream checkpoint: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStreamCheckpoint retrieves the stream checkpoint for a document
|
||||||
|
func (s *PostgresStore) GetStreamCheckpoint(ctx context.Context, documentID uuid.UUID) (*models.StreamCheckpoint, error) {
|
||||||
|
query := `
|
||||||
|
SELECT document_id, last_stream_id, last_seq, updated_at
|
||||||
|
FROM stream_checkpoints
|
||||||
|
WHERE document_id = $1
|
||||||
|
`
|
||||||
|
|
||||||
|
var checkpoint models.StreamCheckpoint
|
||||||
|
if err := s.db.QueryRowContext(ctx, query, documentID).Scan(
|
||||||
|
&checkpoint.DocumentID,
|
||||||
|
&checkpoint.LastStreamID,
|
||||||
|
&checkpoint.LastSeq,
|
||||||
|
&checkpoint.UpdatedAt,
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get stream checkpoint: %w", err)
|
||||||
|
}
|
||||||
|
return &checkpoint, nil
|
||||||
|
}
|
||||||
@@ -71,10 +71,14 @@ func SetupTestDB(t *testing.T) (*PostgresStore, func()) {
|
|||||||
// Run migrations
|
// Run migrations
|
||||||
scriptsDir := filepath.Join("..", "..", "scripts")
|
scriptsDir := filepath.Join("..", "..", "scripts")
|
||||||
migrations := []string{
|
migrations := []string{
|
||||||
"init.sql",
|
"000_extensions.sql",
|
||||||
"001_add_users_and_sessions.sql",
|
"001_init_schema.sql",
|
||||||
"002_add_document_shares.sql",
|
"002_add_users_and_sessions.sql",
|
||||||
"003_add_public_sharing.sql",
|
"003_add_document_shares.sql",
|
||||||
|
"004_add_public_sharing.sql",
|
||||||
|
"005_add_share_link_permission.sql",
|
||||||
|
"010_add_stream_checkpoints.sql",
|
||||||
|
"011_add_update_history.sql",
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, migration := range migrations {
|
for _, migration := range migrations {
|
||||||
@@ -107,6 +111,8 @@ func SetupTestDB(t *testing.T) (*PostgresStore, func()) {
|
|||||||
func TruncateAllTables(ctx context.Context, store *PostgresStore) error {
|
func TruncateAllTables(ctx context.Context, store *PostgresStore) error {
|
||||||
tables := []string{
|
tables := []string{
|
||||||
"document_updates",
|
"document_updates",
|
||||||
|
"document_update_history",
|
||||||
|
"stream_checkpoints",
|
||||||
"document_shares",
|
"document_shares",
|
||||||
"sessions",
|
"sessions",
|
||||||
"documents",
|
"documents",
|
||||||
|
|||||||
115
backend/internal/store/update_history.go
Normal file
115
backend/internal/store/update_history.go
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
package store
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// UpdateHistoryEntry represents a persisted update from Redis Streams
|
||||||
|
// used for recovery and replay.
|
||||||
|
type UpdateHistoryEntry struct {
|
||||||
|
DocumentID uuid.UUID
|
||||||
|
StreamID string
|
||||||
|
Seq int64
|
||||||
|
Payload []byte
|
||||||
|
MsgType string
|
||||||
|
ServerID string
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertUpdateHistoryBatch inserts update history entries in a single batch.
|
||||||
|
// Uses ON CONFLICT DO NOTHING to make inserts idempotent.
|
||||||
|
func (s *PostgresStore) InsertUpdateHistoryBatch(ctx context.Context, entries []UpdateHistoryEntry) error {
|
||||||
|
if len(entries) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var sb strings.Builder
|
||||||
|
sb.WriteString("INSERT INTO document_update_history (document_id, stream_id, seq, payload, msg_type, server_id, created_at) VALUES ")
|
||||||
|
|
||||||
|
args := make([]interface{}, 0, len(entries)*7)
|
||||||
|
for i, e := range entries {
|
||||||
|
if i > 0 {
|
||||||
|
sb.WriteString(",")
|
||||||
|
}
|
||||||
|
base := i*7 + 1
|
||||||
|
sb.WriteString(fmt.Sprintf("($%d,$%d,$%d,$%d,$%d,$%d,$%d)", base, base+1, base+2, base+3, base+4, base+5, base+6))
|
||||||
|
msgType := sanitizeTextForDB(e.MsgType)
|
||||||
|
serverID := sanitizeTextForDB(e.ServerID)
|
||||||
|
args = append(args, e.DocumentID, e.StreamID, e.Seq, e.Payload, nullIfEmpty(msgType), nullIfEmpty(serverID), e.CreatedAt)
|
||||||
|
}
|
||||||
|
// Idempotent insert
|
||||||
|
sb.WriteString(" ON CONFLICT (document_id, stream_id) DO NOTHING")
|
||||||
|
|
||||||
|
if _, err := s.db.ExecContext(ctx, sb.String(), args...); err != nil {
|
||||||
|
return fmt.Errorf("failed to insert update history batch: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListUpdateHistoryAfterSeq returns updates with seq greater than afterSeq, ordered by seq.
|
||||||
|
func (s *PostgresStore) ListUpdateHistoryAfterSeq(ctx context.Context, documentID uuid.UUID, afterSeq int64, limit int) ([]UpdateHistoryEntry, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 1000
|
||||||
|
}
|
||||||
|
query := `
|
||||||
|
SELECT document_id, stream_id, seq, payload, COALESCE(msg_type, ''), COALESCE(server_id, ''), created_at
|
||||||
|
FROM document_update_history
|
||||||
|
WHERE document_id = $1 AND seq > $2
|
||||||
|
ORDER BY seq ASC
|
||||||
|
LIMIT $3
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := s.db.QueryContext(ctx, query, documentID, afterSeq, limit)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list update history: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var results []UpdateHistoryEntry
|
||||||
|
for rows.Next() {
|
||||||
|
var e UpdateHistoryEntry
|
||||||
|
if err := rows.Scan(&e.DocumentID, &e.StreamID, &e.Seq, &e.Payload, &e.MsgType, &e.ServerID, &e.CreatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to scan update history: %w", err)
|
||||||
|
}
|
||||||
|
results = append(results, e)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteUpdateHistoryUpToSeq deletes updates with seq <= maxSeq for a document.
|
||||||
|
func (s *PostgresStore) DeleteUpdateHistoryUpToSeq(ctx context.Context, documentID uuid.UUID, maxSeq int64) error {
|
||||||
|
query := `
|
||||||
|
DELETE FROM document_update_history
|
||||||
|
WHERE document_id = $1 AND seq <= $2
|
||||||
|
`
|
||||||
|
if _, err := s.db.ExecContext(ctx, query, documentID, maxSeq); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete update history: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nullIfEmpty maps an empty string to nil (stored as SQL NULL) and passes
// any other value through unchanged.
func nullIfEmpty(s string) interface{} {
	if s != "" {
		return s
	}
	return nil
}
|
||||||
|
|
||||||
|
// sanitizeTextForDB returns s unchanged when it is safe to store in a
// Postgres TEXT column, and "" otherwise. Postgres rejects NUL bytes and
// invalid UTF-8 in text values, so such inputs are blanked rather than
// allowed to fail the whole insert.
func sanitizeTextForDB(s string) string {
	switch {
	case s == "":
		return ""
	case strings.IndexByte(s, 0) >= 0:
		return ""
	case !utf8.ValidString(s):
		return ""
	default:
		return s
	}
}
|
||||||
320
backend/internal/workers/update_persist_worker.go
Normal file
320
backend/internal/workers/update_persist_worker.go
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
package workers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"runtime/debug"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unicode/utf8"
|
||||||
|
|
||||||
|
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
|
||||||
|
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// Consumer group through which this worker reads every document stream.
	updatePersistGroupName = "update-persist-worker"
	// Distributed lock key ensuring only one worker instance runs at a time.
	updatePersistLockKey = "lock:update-persist-worker"
	// Lock lifetime; the holder refreshes it at half this interval.
	updatePersistLockTTL = 30 * time.Second
	// Interval between persistence passes (and between lock-acquire retries).
	updatePersistTick = 2 * time.Second
	// Max entries fetched per XREADGROUP / XAUTOCLAIM call.
	updateReadCount = 200
	updateReadBlock = -1 // negative → go-redis omits BLOCK clause → non-blocking
	// Initial capacity for the per-document batch of history entries.
	updateBatchSize = 500
	// Number of trailing updates kept in Postgres behind the latest
	// snapshot checkpoint when pruning history.
	updateSafeSeqLag = int64(1000)
	// Pending entries idle longer than this are reclaimed from dead consumers.
	updateAutoClaimIdle = 30 * time.Second
	// Liveness log interval for the lock-holding worker.
	updateHeartbeatEvery = 30 * time.Second
)
|
||||||
|
|
||||||
|
// StartUpdatePersistWorker persists Redis Stream updates into Postgres for recovery.
// It loops forever (until ctx is canceled), competing for a distributed lock on
// each pass; the winner runs runUpdatePersistWorker until the lock is lost or
// ctx ends. Intended to be run once per server process, in its own goroutine
// or as the goroutine body itself — it blocks. No-op if either dependency is nil.
func StartUpdatePersistWorker(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) {
	if msgBus == nil || dbStore == nil {
		return
	}

	for {
		// Each attempt runs inside an anonymous function so a panic anywhere
		// in the attempt (including runUpdatePersistWorker) is recovered here
		// and the outer loop can retry. Note: `return` below exits only this
		// closure, i.e. the current attempt — not the whole worker.
		func() {
			defer func() {
				if r := recover(); r != nil {
					logWorker(logger, "Update persist worker panic",
						zap.Any("panic", r),
						zap.ByteString("stack", debug.Stack()))
				}
			}()

			// Bail out of the attempt early if shutdown already started.
			select {
			case <-ctx.Done():
				return
			default:
			}

			acquired, err := msgBus.AcquireLock(ctx, updatePersistLockKey, updatePersistLockTTL)
			if err != nil {
				logWorker(logger, "Failed to acquire update persist worker lock", zap.Error(err))
				time.Sleep(updatePersistTick)
				return
			}
			if !acquired {
				// Another instance holds the lock; back off and retry.
				time.Sleep(updatePersistTick)
				return
			}

			logWorker(logger, "Update persist worker lock acquired", zap.String("server_id", serverID))
			// Blocks until ctx is canceled or the lock is lost.
			runUpdatePersistWorker(ctx, msgBus, dbStore, logger, serverID)
		}()

		// Stop looping once the context is done.
		select {
		case <-ctx.Done():
			return
		default:
		}
		// If the worker exited (including panic), pause briefly before retry.
		time.Sleep(updatePersistTick)
	}
}
|
||||||
|
|
||||||
|
func runUpdatePersistWorker(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) {
|
||||||
|
ticker := time.NewTicker(updatePersistTick)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
refreshTicker := time.NewTicker(updatePersistLockTTL / 2)
|
||||||
|
defer refreshTicker.Stop()
|
||||||
|
|
||||||
|
heartbeatTicker := time.NewTicker(updateHeartbeatEvery)
|
||||||
|
defer heartbeatTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = msgBus.ReleaseLock(ctx, updatePersistLockKey)
|
||||||
|
return
|
||||||
|
case <-refreshTicker.C:
|
||||||
|
ok, err := msgBus.RefreshLock(ctx, updatePersistLockKey, updatePersistLockTTL)
|
||||||
|
if err != nil || !ok {
|
||||||
|
logWorker(logger, "Update persist worker lock lost", zap.Error(err))
|
||||||
|
_ = msgBus.ReleaseLock(ctx, updatePersistLockKey)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case <-heartbeatTicker.C:
|
||||||
|
logWorker(logger, "Update persist worker heartbeat", zap.String("server_id", serverID))
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := processUpdatePersistence(ctx, msgBus, dbStore, logger, serverID); err != nil {
|
||||||
|
logWorker(logger, "Update persist worker tick failed", zap.Error(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processUpdatePersistence runs one persistence pass: for every document with
// recent stream activity it claims/reads pending entries from the document's
// Redis Stream via the worker's consumer group, batch-inserts decoded updates
// into Postgres, and ACKs only what was successfully persisted (so failures
// are retried on the next tick). Per-document failures are logged and skipped;
// only a failure to enumerate active streams aborts the pass.
func processUpdatePersistence(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) error {
	// Only process documents with recent stream activity (active in the last 60 seconds)
	cutoff := float64(time.Now().Add(-60 * time.Second).Unix())
	activeDocIDs, err := msgBus.ZRangeByScore(ctx, "active-streams", cutoff, float64(time.Now().Unix()))
	if err != nil {
		return fmt.Errorf("failed to get active streams: %w", err)
	}

	// Prune stale entries older than 5 minutes (best-effort cleanup)
	stale := float64(time.Now().Add(-5 * time.Minute).Unix())
	if _, err := msgBus.ZRemRangeByScore(ctx, "active-streams", 0, stale); err != nil {
		logWorker(logger, "Failed to prune stale active-streams entries", zap.Error(err))
	}

	for _, docIDStr := range activeDocIDs {
		docID, err := uuid.Parse(docIDStr)
		if err != nil {
			logWorker(logger, "Invalid document ID in active-streams", zap.String("doc_id", docIDStr))
			continue
		}

		// NOTE(review): stream key layout "stream:<uuid>" must match the
		// publisher side — confirm against where updates are XADDed.
		streamKey := "stream:" + docIDStr
		if err := ensureConsumerGroup(ctx, msgBus, streamKey, updatePersistGroupName); err != nil {
			logWorker(logger, "Failed to ensure update persist consumer group", zap.String("stream", streamKey), zap.Error(err))
			continue
		}

		var ackIDs []string
		docEntries := make([]store.UpdateHistoryEntry, 0, updateBatchSize)

		// First, try to claim idle pending messages (e.g., from previous crashes)
		claimed, _, err := msgBus.XAutoClaim(ctx, streamKey, updatePersistGroupName, serverID, updateAutoClaimIdle, "0-0", updateReadCount)
		if err != nil {
			// Non-fatal: fall through and still attempt a fresh read below.
			logWorker(logger, "XAutoClaim failed", zap.String("stream", streamKey), zap.Error(err))
		} else if len(claimed) > 0 {
			collectStreamMessages(ctx, msgBus, dbStore, logger, docID, streamKey, claimed, &docEntries, &ackIDs)
		}

		// Then read new, never-delivered entries (">" cursor); non-blocking
		// because updateReadBlock is negative.
		messages, err := msgBus.XReadGroup(ctx, updatePersistGroupName, serverID, []string{streamKey, ">"}, updateReadCount, updateReadBlock)
		if err != nil {
			logWorker(logger, "XReadGroup failed", zap.String("stream", streamKey), zap.Error(err))
			continue
		}
		if len(messages) > 0 {
			collectStreamMessages(ctx, msgBus, dbStore, logger, docID, streamKey, messages, &docEntries, &ackIDs)
		}

		if len(docEntries) > 0 {
			if err := dbStore.InsertUpdateHistoryBatch(ctx, docEntries); err != nil {
				logWorker(logger, "Failed to insert update history batch", zap.Error(err))
				// Skip ACK to retry on next tick
				continue
			}
		}

		// ACK everything collected (updates now persisted, plus snapshot
		// markers handled inside collectStreamMessages).
		if len(ackIDs) > 0 {
			if _, err := msgBus.XAck(ctx, streamKey, updatePersistGroupName, ackIDs...); err != nil {
				logWorker(logger, "XAck failed", zap.String("stream", streamKey), zap.Error(err))
			}
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
// collectStreamMessages walks a batch of stream messages for one document.
// "update" messages are base64-decoded and appended to docEntries for a later
// batch insert; "snapshot" marker messages trigger checkpoint upsert, history
// pruning, and Redis stream trimming inline (all best-effort). Every message
// ID — including unknown types and undecodable updates that were skipped —
// is appended to ackIDs so the caller can ACK it and it is not redelivered.
// docEntries and ackIDs are output parameters accumulated across calls.
func collectStreamMessages(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, documentID uuid.UUID, streamKey string, messages []messagebus.StreamMessage, docEntries *[]store.UpdateHistoryEntry, ackIDs *[]string) {
	for _, msg := range messages {
		msgType := getString(msg.Values["type"])
		switch msgType {
		case "update":
			payloadB64 := getString(msg.Values["yjs_payload"])
			payload, err := base64.StdEncoding.DecodeString(payloadB64)
			if err != nil {
				// Undecodable payload: log and skip, but still ACK below so
				// the poison message is not redelivered forever.
				logWorker(logger, "Failed to decode update payload",
					zap.String("stream", streamKey),
					zap.String("stream_id", msg.ID),
					zap.Error(err))
				continue
			}
			seq := parseInt64(msg.Values["seq"])
			// NOTE(review): this shadows the outer msgType (field "type");
			// here it re-reads field "msg_type". Intentional but confusing —
			// consider renaming the inner variable.
			msgType := normalizeMsgType(msg.Values["msg_type"])
			serverID := sanitizeText(getString(msg.Values["server_id"]))
			entry := store.UpdateHistoryEntry{
				DocumentID: documentID,
				StreamID:   msg.ID,
				Seq:        seq,
				Payload:    payload,
				MsgType:    msgType,
				ServerID:   serverID,
				CreatedAt:  time.Now().UTC(),
			}
			*docEntries = append(*docEntries, entry)
		case "snapshot":
			seq := parseInt64(msg.Values["seq"])
			if seq > 0 {
				// Record the snapshot position as the document's checkpoint.
				if err := dbStore.UpsertStreamCheckpoint(ctx, documentID, msg.ID, seq); err != nil {
					logWorker(logger, "Failed to upsert stream checkpoint from snapshot marker",
						zap.String("document_id", documentID.String()),
						zap.Error(err))
				}
				// Retention: prune DB history based on checkpoint (best-effort)
				// keeping updateSafeSeqLag trailing updates behind the snapshot.
				maxSeq := seq - updateSafeSeqLag
				if maxSeq > 0 {
					if err := dbStore.DeleteUpdateHistoryUpToSeq(ctx, documentID, maxSeq); err != nil {
						logWorker(logger, "Failed to prune update history",
							zap.String("document_id", documentID.String()),
							zap.Error(err))
					}
				}
				// Trim Redis stream to avoid unbounded growth (best-effort)
				if _, err := msgBus.XTrimMinID(ctx, streamKey, msg.ID); err != nil {
					logWorker(logger, "Failed to trim Redis stream",
						zap.String("stream", streamKey),
						zap.Error(err))
				}
			}
		}
		// ACK regardless of message type/outcome (see function comment).
		*ackIDs = append(*ackIDs, msg.ID)
	}
}
|
||||||
|
|
||||||
|
func ensureConsumerGroup(ctx context.Context, msgBus messagebus.MessageBus, streamKey, group string) error {
|
||||||
|
if err := msgBus.XGroupCreateMkStream(ctx, streamKey, group, "0-0"); err != nil {
|
||||||
|
if !isBusyGroup(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// isBusyGroup reports whether err is Redis's BUSYGROUP reply, which XGROUP
// CREATE returns when the consumer group already exists.
func isBusyGroup(err error) bool {
	return err != nil && strings.Contains(err.Error(), "BUSYGROUP")
}
|
||||||
|
|
||||||
|
// getString coerces a Redis stream field value to a string. Stream field
// values arrive as string or []byte; any other type is rendered via
// fmt.Sprint as a fallback.
func getString(value interface{}) string {
	if s, ok := value.(string); ok {
		return s
	}
	if b, ok := value.([]byte); ok {
		return string(b)
	}
	return fmt.Sprint(value)
}
|
||||||
|
|
||||||
|
// parseInt64 best-effort converts a Redis stream field value to int64.
// Numeric types are converted directly; string and []byte are parsed as
// base-10. Unparseable or unsupported values yield 0.
func parseInt64(value interface{}) int64 {
	switch v := value.(type) {
	case int64:
		return v
	case int:
		return int64(v)
	case uint64:
		return int64(v)
	case string:
		parsed, err := strconv.ParseInt(v, 10, 64)
		if err != nil {
			return 0
		}
		return parsed
	case []byte:
		parsed, err := strconv.ParseInt(string(v), 10, 64)
		if err != nil {
			return 0
		}
		return parsed
	default:
		return 0
	}
}
|
||||||
|
|
||||||
|
// sanitizeText blanks strings that cannot be stored safely as Postgres
// text: anything containing a NUL byte or invalid UTF-8 becomes "".
// All other values (including the empty string) pass through unchanged.
func sanitizeText(s string) string {
	switch {
	case s == "":
		return s
	case strings.IndexByte(s, 0) >= 0:
		return ""
	case !utf8.ValidString(s):
		return ""
	default:
		return s
	}
}
|
||||||
|
|
||||||
|
func normalizeMsgType(value interface{}) string {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case string:
|
||||||
|
if v == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(v) == 1 {
|
||||||
|
return strconv.Itoa(int(v[0]))
|
||||||
|
}
|
||||||
|
return sanitizeText(v)
|
||||||
|
case []byte:
|
||||||
|
if len(v) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if len(v) == 1 {
|
||||||
|
return strconv.Itoa(int(v[0]))
|
||||||
|
}
|
||||||
|
return sanitizeText(string(v))
|
||||||
|
default:
|
||||||
|
return sanitizeText(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func logWorker(logger *zap.Logger, msg string, fields ...zap.Field) {
|
||||||
|
if logger == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger.Info(msg, fields...)
|
||||||
|
}
|
||||||
12
backend/scripts/010_add_stream_checkpoints.sql
Normal file
12
backend/scripts/010_add_stream_checkpoints.sql
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
-- Migration: Add stream checkpoints table for Redis Streams durability
-- This table tracks last processed stream position per document

-- One row per document; the worker upserts it when it sees a snapshot marker.
CREATE TABLE IF NOT EXISTS stream_checkpoints (
    -- Checkpoint rows are deleted with their document.
    document_id UUID PRIMARY KEY REFERENCES documents(id) ON DELETE CASCADE,
    -- Redis Stream entry ID of the last processed entry.
    last_stream_id TEXT NOT NULL,
    -- Per-document sequence number of the last processed update.
    last_seq BIGINT NOT NULL DEFAULT 0,
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Supports scanning checkpoints by recency (most recently active documents).
CREATE INDEX IF NOT EXISTS idx_stream_checkpoints_updated_at
    ON stream_checkpoints(updated_at DESC);
||||||
22
backend/scripts/011_add_update_history.sql
Normal file
22
backend/scripts/011_add_update_history.sql
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
-- Migration: Add update history table for Redis Stream WAL
-- This table stores per-update payloads for recovery and replay

CREATE TABLE IF NOT EXISTS document_update_history (
    id BIGSERIAL PRIMARY KEY,
    -- History rows are deleted with their document.
    document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
    -- Redis Stream entry ID the update was read from; combined with
    -- document_id it makes the writer's batch insert idempotent
    -- (ON CONFLICT (document_id, stream_id) DO NOTHING).
    stream_id TEXT NOT NULL,
    -- Per-document sequence number used for ordered replay and pruning.
    seq BIGINT NOT NULL,
    -- Raw update bytes (base64-decoded from the stream by the worker).
    payload BYTEA NOT NULL,
    -- Optional message-type tag; NULL when the stream entry carried none.
    msg_type TEXT,
    -- Optional originating server ID; NULL when absent.
    server_id TEXT,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_stream_id
    ON document_update_history(document_id, stream_id);

CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_seq
    ON document_update_history(document_id, seq);

-- NOTE(review): the index below duplicates uniq_update_history_document_seq
-- on the same columns; one of the two is redundant and could be dropped in a
-- follow-up migration.
CREATE INDEX IF NOT EXISTS idx_update_history_document_seq
    ON document_update_history(document_id, seq);
|
||||||
@@ -24,8 +24,11 @@ services:
|
|||||||
redis:
|
redis:
|
||||||
image: redis:7-alpine
|
image: redis:7-alpine
|
||||||
container_name: realtime-collab-redis
|
container_name: realtime-collab-redis
|
||||||
|
command: ["redis-server", "--appendonly", "yes"]
|
||||||
ports:
|
ports:
|
||||||
- "6379:6379"
|
- "6379:6379"
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
interval: 10s
|
interval: 10s
|
||||||
@@ -34,3 +37,4 @@ services:
|
|||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
|
redis_data:
|
||||||
|
|||||||
@@ -53,8 +53,11 @@ export const documentsApi = {
|
|||||||
},
|
},
|
||||||
|
|
||||||
// Get document Yjs state
|
// Get document Yjs state
|
||||||
getState: async (id: string): Promise<Uint8Array> => {
|
getState: async (id: string, shareToken?: string): Promise<Uint8Array> => {
|
||||||
const response = await authFetch(`${API_BASE_URL}/documents/${id}/state`);
|
const url = shareToken
|
||||||
|
? `${API_BASE_URL}/documents/${id}/state?share=${shareToken}`
|
||||||
|
: `${API_BASE_URL}/documents/${id}/state`;
|
||||||
|
const response = await authFetch(url);
|
||||||
if (!response.ok) throw new Error("Failed to fetch document state");
|
if (!response.ok) throw new Error("Failed to fetch document state");
|
||||||
const arrayBuffer = await response.arrayBuffer();
|
const arrayBuffer = await response.arrayBuffer();
|
||||||
return new Uint8Array(arrayBuffer);
|
return new Uint8Array(arrayBuffer);
|
||||||
@@ -167,4 +170,4 @@ export const versionsApi = {
|
|||||||
if (!response.ok) throw new Error('Failed to restore version');
|
if (!response.ok) throw new Error('Failed to restore version');
|
||||||
return response.json();
|
return response.json();
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -157,10 +157,34 @@ export const useYjsDocument = (documentId: string, shareToken?: string) => {
|
|||||||
setSynced(true);
|
setSynced(true);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Connection stability monitoring with reconnection limits
|
||||||
|
let reconnectCount = 0;
|
||||||
|
const maxReconnects = 10;
|
||||||
|
|
||||||
yjsProviders.websocketProvider.on(
|
yjsProviders.websocketProvider.on(
|
||||||
"status",
|
"status",
|
||||||
(event: { status: string }) => {
|
(event: { status: string }) => {
|
||||||
console.log("WebSocket status:", event.status);
|
console.log("WebSocket status:", event.status);
|
||||||
|
|
||||||
|
if (event.status === "disconnected") {
|
||||||
|
reconnectCount++;
|
||||||
|
if (reconnectCount >= maxReconnects) {
|
||||||
|
console.error(
|
||||||
|
"Max reconnection attempts reached. Please refresh the page."
|
||||||
|
);
|
||||||
|
// Could optionally show a user notification here
|
||||||
|
} else {
|
||||||
|
console.log(
|
||||||
|
`Reconnection attempt ${reconnectCount}/${maxReconnects}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
} else if (event.status === "connected") {
|
||||||
|
// Reset counter on successful connection
|
||||||
|
if (reconnectCount > 0) {
|
||||||
|
console.log("Reconnected successfully, resetting counter");
|
||||||
|
}
|
||||||
|
reconnectCount = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ export const createYjsDocument = async (
|
|||||||
|
|
||||||
// Load initial state from database BEFORE connecting providers
|
// Load initial state from database BEFORE connecting providers
|
||||||
try {
|
try {
|
||||||
const state = await documentsApi.getState(documentId);
|
const state = await documentsApi.getState(documentId, shareToken);
|
||||||
if (state && state.length > 0) {
|
if (state && state.length > 0) {
|
||||||
Y.applyUpdate(ydoc, state);
|
Y.applyUpdate(ydoc, state);
|
||||||
console.log('✓ Loaded document state from database');
|
console.log('✓ Loaded document state from database');
|
||||||
@@ -51,7 +51,10 @@ export const createYjsDocument = async (
|
|||||||
wsUrl,
|
wsUrl,
|
||||||
documentId,
|
documentId,
|
||||||
ydoc,
|
ydoc,
|
||||||
{ params: wsParams }
|
{
|
||||||
|
params: wsParams,
|
||||||
|
maxBackoffTime: 10000, // Max 10s between reconnect attempts
|
||||||
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
// Awareness for cursors and presence
|
// Awareness for cursors and presence
|
||||||
|
|||||||
Reference in New Issue
Block a user