feat: implement Redis Streams support with stream checkpoints and update history
- Added Redis Streams operations to the message bus interface and implementation.
- Introduced StreamCheckpoint model to track last processed stream entry per document.
- Implemented UpsertStreamCheckpoint and GetStreamCheckpoint methods in the Postgres store.
- Created document_update_history table for storing update payloads for recovery and replay.
- Developed update persist worker to handle Redis Stream updates and persist them to Postgres.
- Enhanced Docker Compose configuration for Redis with persistence.
- Updated frontend API to support fetching document state with optional share token.
- Added connection stability monitoring in the Yjs document hook.
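How the pieces compose at startup (a minimal sketch; main.go is not part of this diff, so the Redis URL parameter and the exact NewHub arity — its signature is truncated in a hunk below — are assumptions):

package main // hypothetical wiring; the real main.go is not shown in this diff

import (
    "context"

    "github.com/M1ngdaXie/realtime-collab/internal/handlers"
    "github.com/M1ngdaXie/realtime-collab/internal/hub"
    "github.com/M1ngdaXie/realtime-collab/internal/messagebus"
    "github.com/M1ngdaXie/realtime-collab/internal/store"
    "github.com/M1ngdaXie/realtime-collab/internal/workers"
    "github.com/google/uuid"
    "go.uber.org/zap"
)

func wireStreaming(ctx context.Context, redisURL string, db *store.PostgresStore, logger *zap.Logger) {
    serverID := uuid.NewString()

    // Fall back to the no-op LocalMessageBus when Redis is unavailable,
    // matching the single-server mode used by the tests in this diff.
    var bus messagebus.MessageBus = messagebus.NewLocalMessageBus()
    if rb, err := messagebus.NewRedisMessageBus(redisURL, serverID, logger); err == nil {
        bus = rb
    }

    h := hub.NewHub(bus, serverID, logger) // assumed arity; signature truncated below
    go h.Run()

    _ = handlers.NewDocumentHandler(db, bus, serverID, logger)

    // StartUpdatePersistWorker blocks in its lock/tick loop, so run it on
    // its own goroutine; it exits when ctx is cancelled.
    go workers.StartUpdatePersistWorker(ctx, bus, db, logger, serverID)
}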
@@ -9,6 +9,7 @@ import (
    "github.com/gin-gonic/gin"
    "github.com/golang-jwt/jwt/v5"
    "github.com/google/uuid"
    "go.uber.org/zap"
)

type contextKey string
@@ -20,41 +21,40 @@ const ContextUserIDKey = "user_id"
type AuthMiddleware struct {
    store     store.Store
    jwtSecret string
    logger    *zap.Logger
}

// NewAuthMiddleware creates a new auth middleware
func NewAuthMiddleware(store store.Store, jwtSecret string) *AuthMiddleware {
func NewAuthMiddleware(store store.Store, jwtSecret string, logger *zap.Logger) *AuthMiddleware {
    if logger == nil {
        logger = zap.NewNop()
    }
    return &AuthMiddleware{
        store:     store,
        jwtSecret: jwtSecret,
        logger:    logger,
    }
}

// RequireAuth middleware requires valid authentication
func (m *AuthMiddleware) RequireAuth() gin.HandlerFunc {
    return func(c *gin.Context) {
        fmt.Println("🔒 RequireAuth: Starting authentication check")

        user, claims, err := m.getUserFromToken(c)

        fmt.Printf("🔒 RequireAuth: user=%v, err=%v\n", user, err)
        if claims != nil {
            fmt.Printf("🔒 RequireAuth: claims.Name=%s, claims.Email=%s\n", claims.Name, claims.Email)
        }

        if err != nil || user == nil {
            fmt.Printf("❌ RequireAuth: FAILED - err=%v, user=%v\n", err, user)
            if err != nil {
                m.logger.Warn("auth failed",
                    zap.Error(err),
                    zap.String("method", c.Request.Method),
                    zap.String("path", c.FullPath()),
                )
            }
            c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
            c.Abort()
            return
        }

        // Note: Name and Email might be empty for old JWT tokens
        if claims.Name == "" || claims.Email == "" {
            fmt.Printf("⚠️ RequireAuth: WARNING - Token missing name/email (using old token format)\n")
        }

        fmt.Printf("✅ RequireAuth: SUCCESS - setting context for user %v\n", user)
        c.Set(ContextUserIDKey, user)
        c.Set("user_email", claims.Email)
        c.Set("user_name", claims.Name)
@@ -88,21 +88,17 @@ func (m *AuthMiddleware) OptionalAuth() gin.HandlerFunc {
// Note: the return values changed; this now returns (*uuid.UUID, *UserClaims, error)
func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClaims, error) {
    authHeader := c.GetHeader("Authorization")
    fmt.Printf("🔍 getUserFromToken: Authorization header = '%s'\n", authHeader)

    if authHeader == "" {
        fmt.Println("⚠️ getUserFromToken: No Authorization header")
        return nil, nil, nil
    }

    parts := strings.Split(authHeader, " ")
    if len(parts) != 2 || parts[0] != "Bearer" {
        fmt.Printf("⚠️ getUserFromToken: Invalid header format (parts=%d, prefix=%s)\n", len(parts), parts[0])
        return nil, nil, nil
    }

    tokenString := parts[1]
    fmt.Printf("🔍 getUserFromToken: Token = %s...\n", tokenString[:min(20, len(tokenString))])

    token, err := jwt.ParseWithClaims(tokenString, &UserClaims{}, func(token *jwt.Token) (interface{}, error) {
        // Must verify that the signing algorithm is HMAC (HS256)
@@ -113,7 +109,6 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
    })

    if err != nil {
        fmt.Printf("❌ getUserFromToken: JWT parse error: %v\n", err)
        return nil, nil, err
    }

@@ -123,17 +118,14 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
        // Because GenerateJWT stores claims.Subject = userID.String()
        userID, err := uuid.Parse(claims.Subject)
        if err != nil {
            fmt.Printf("❌ getUserFromToken: Invalid UUID in subject: %v\n", err)
            return nil, nil, fmt.Errorf("invalid user ID in token")
        }

        // Success: return the UUID and claims directly (they include Name and Email).
        // This path never queries the database, so it is very fast.
        fmt.Printf("✅ getUserFromToken: SUCCESS - userID=%v, name=%s, email=%s\n", userID, claims.Name, claims.Email)
        return &userID, claims, nil
    }

    fmt.Println("❌ getUserFromToken: Invalid token claims or token not valid")
    return nil, nil, fmt.Errorf("invalid token claims")
}

@@ -141,8 +133,6 @@ func (m *AuthMiddleware) getUserFromToken(c *gin.Context) (*uuid.UUID, *UserClai
func GetUserFromContext(c *gin.Context) *uuid.UUID {
    // Fix: use exactly the same key that was used when storing the value
    val, exists := c.Get(ContextUserIDKey)
    fmt.Println("within getFromContext the id is ... ")
    fmt.Println(val)
    if !exists {
        return nil
    }
@@ -1,22 +1,33 @@
package handlers

import (
    "fmt"
    "context"
    "net/http"
    "time"

    "github.com/M1ngdaXie/realtime-collab/internal/auth"
    "github.com/M1ngdaXie/realtime-collab/internal/messagebus"
    "github.com/M1ngdaXie/realtime-collab/internal/models"
    "github.com/M1ngdaXie/realtime-collab/internal/store"
    "github.com/gin-gonic/gin"
    "github.com/google/uuid"
    "go.uber.org/zap"
)

type DocumentHandler struct {
    store *store.PostgresStore
    store      *store.PostgresStore
    messageBus messagebus.MessageBus
    serverID   string
    logger     *zap.Logger
}

func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
    return &DocumentHandler{store: s}
func NewDocumentHandler(s *store.PostgresStore, msgBus messagebus.MessageBus, serverID string, logger *zap.Logger) *DocumentHandler {
    return &DocumentHandler{
        store:      s,
        messageBus: msgBus,
        serverID:   serverID,
        logger:     logger,
    }
}

// CreateDocument creates a new document (requires auth)
@@ -45,8 +56,6 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {

func (h *DocumentHandler) ListDocuments(c *gin.Context) {
    userID := auth.GetUserFromContext(c)
    fmt.Println("Getting userId, which is : ")
    fmt.Println(userID)
    if userID == nil {
        respondUnauthorized(c, "Authentication required to list documents")
        return
@@ -113,6 +122,13 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
    }

    userID := auth.GetUserFromContext(c)
    shareToken := c.Query("share")

    doc, err := h.store.GetDocument(id)
    if err != nil {
        respondNotFound(c, "document")
        return
    }

    // Check permission if authenticated
    if userID != nil {
@@ -125,12 +141,22 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
            respondForbidden(c, "Access denied")
            return
        }
    }

        doc, err := h.store.GetDocument(id)
        if err != nil {
            respondNotFound(c, "document")
            return
    } else {
        // Unauthenticated: require valid share token or public doc
        if shareToken != "" {
            valid, err := h.store.ValidateShareToken(c.Request.Context(), id, shareToken)
            if err != nil {
                respondInternalError(c, "Failed to validate share token", err)
                return
            }
            if !valid {
                respondForbidden(c, "Invalid or expired share token")
                return
            }
        } else if !doc.Is_Public {
            respondForbidden(c, "This document is not public. Please sign in to access.")
            return
        }
    }

    // Return empty byte slice if state is nil (new document)
@@ -191,6 +217,16 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
        return
    }

    if streamID, seq, ok := h.addSnapshotMarker(c.Request.Context(), id); ok {
        if err := h.store.UpsertStreamCheckpoint(c.Request.Context(), id, streamID, seq); err != nil {
            if h.logger != nil {
                h.logger.Warn("Failed to upsert stream checkpoint after snapshot",
                    zap.String("document_id", id.String()),
                    zap.Error(err))
            }
        }
    }

    c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
}

@@ -234,6 +270,43 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
    c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"})
}

func (h *DocumentHandler) addSnapshotMarker(ctx context.Context, documentID uuid.UUID) (string, int64, bool) {
    if h.messageBus == nil {
        return "", 0, false
    }

    streamKey := "stream:" + documentID.String()
    seqKey := "seq:" + documentID.String()

    seq, err := h.messageBus.Incr(ctx, seqKey)
    if err != nil {
        if h.logger != nil {
            h.logger.Warn("Failed to increment snapshot sequence",
                zap.String("document_id", documentID.String()),
                zap.Error(err))
        }
        return "", 0, false
    }

    values := map[string]interface{}{
        "type":      "snapshot",
        "server_id": h.serverID,
        "seq":       seq,
        "timestamp": time.Now().Format(time.RFC3339),
    }

    streamID, err := h.messageBus.XAdd(ctx, streamKey, 10000, true, values)
    if err != nil {
        if h.logger != nil {
            h.logger.Warn("Failed to add snapshot marker to stream",
                zap.String("stream_key", streamKey),
                zap.Error(err))
        }
        return "", 0, false
    }
    return streamID, seq, true
}

// GetDocumentPermission returns the user's permission level for a document
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
    documentID, err := uuid.Parse(c.Param("id"))
@@ -7,10 +7,12 @@ import (
    "testing"

    "github.com/M1ngdaXie/realtime-collab/internal/auth"
    "github.com/M1ngdaXie/realtime-collab/internal/messagebus"
    "github.com/M1ngdaXie/realtime-collab/internal/models"
    "github.com/gin-gonic/gin"
    "github.com/google/uuid"
    "github.com/stretchr/testify/suite"
    "go.uber.org/zap"
)

// DocumentHandlerSuite tests document CRUD operations
@@ -23,7 +25,7 @@ type DocumentHandlerSuite struct {
// SetupTest runs before each test
func (s *DocumentHandlerSuite) SetupTest() {
    s.BaseHandlerSuite.SetupTest()
    s.handler = NewDocumentHandler(s.store)
    s.handler = NewDocumentHandler(s.store, messagebus.NewLocalMessageBus(), "test-server", zap.NewNop())
    s.setupRouter()
}

@@ -10,6 +10,7 @@ import (
    "github.com/gin-gonic/gin"
    "github.com/google/uuid"
    "github.com/stretchr/testify/suite"
    "go.uber.org/zap"
)

// ShareHandlerSuite tests for share handler endpoints
@@ -24,7 +25,7 @@ func (s *ShareHandlerSuite) SetupTest() {
    s.BaseHandlerSuite.SetupTest()

    // Create handler and router
    authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret)
    authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret, zap.NewNop())
    s.handler = NewShareHandler(s.store, s.cfg)
    s.router = gin.New()
@@ -1,13 +1,18 @@
package handlers

import (
    "context"
    "encoding/base64"
    "fmt"
    "log"
    "net/http"
    "strconv"
    "time"

    "github.com/M1ngdaXie/realtime-collab/internal/auth"
    "github.com/M1ngdaXie/realtime-collab/internal/config"
    "github.com/M1ngdaXie/realtime-collab/internal/hub"
    "github.com/M1ngdaXie/realtime-collab/internal/messagebus"
    "github.com/M1ngdaXie/realtime-collab/internal/store"
    "github.com/gin-gonic/gin"
    "github.com/google/uuid"
@@ -19,16 +24,18 @@ import (
var connectionSem = make(chan struct{}, 200)

type WebSocketHandler struct {
    hub *hub.Hub
    store store.Store
    cfg *config.Config
    hub    *hub.Hub
    store  store.Store
    cfg    *config.Config
    msgBus messagebus.MessageBus
}

func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config) *WebSocketHandler {
func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config, msgBus messagebus.MessageBus) *WebSocketHandler {
    return &WebSocketHandler{
        hub: h,
        store: s,
        cfg: cfg,
        hub:    h,
        store:  s,
        cfg:    cfg,
        msgBus: msgBus,
    }
}

@@ -170,6 +177,105 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
    // Start goroutines
    go client.WritePump()
    go client.ReadPump()
    go wsh.replayBacklog(client, documentID)

    log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID)
}

const maxReplayUpdates = 5000

func (wsh *WebSocketHandler) replayBacklog(client *hub.Client, documentID uuid.UUID) {
    ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    defer cancel()

    checkpoint, err := wsh.store.GetStreamCheckpoint(ctx, documentID)
    if err != nil || checkpoint == nil || checkpoint.LastStreamID == "" {
        return
    }

    streamKey := "stream:" + documentID.String()
    var sent int

    // Primary: Redis stream replay
    if wsh.msgBus != nil {
        messages, err := wsh.msgBus.XRange(ctx, streamKey, checkpoint.LastStreamID, "+")
        if err == nil && len(messages) > 0 {
            for _, msg := range messages {
                if msg.ID == checkpoint.LastStreamID {
                    continue
                }
                if sent >= maxReplayUpdates {
                    log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String())
                    return
                }
                msgType := getString(msg.Values["type"])
                if msgType != "update" {
                    continue
                }
                seq := parseInt64(msg.Values["seq"])
                if seq <= checkpoint.LastSeq {
                    continue
                }
                payloadB64 := getString(msg.Values["yjs_payload"])
                payload, err := base64.StdEncoding.DecodeString(payloadB64)
                if err != nil {
                    continue
                }
                if client.Enqueue(payload) {
                    sent++
                } else {
                    return
                }
            }
            return
        }
    }

    // Fallback: DB history replay
    updates, err := wsh.store.ListUpdateHistoryAfterSeq(ctx, documentID, checkpoint.LastSeq, maxReplayUpdates)
    if err != nil {
        return
    }
    for _, upd := range updates {
        if sent >= maxReplayUpdates {
            log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String())
            return
        }
        if client.Enqueue(upd.Payload) {
            sent++
        } else {
            return
        }
    }
}

func getString(value interface{}) string {
    switch v := value.(type) {
    case string:
        return v
    case []byte:
        return string(v)
    default:
        return fmt.Sprint(v)
    }
}

func parseInt64(value interface{}) int64 {
    switch v := value.(type) {
    case int64:
        return v
    case int:
        return int64(v)
    case uint64:
        return int64(v)
    case string:
        if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
            return parsed
        }
    case []byte:
        if parsed, err := strconv.ParseInt(string(v), 10, 64); err == nil {
            return parsed
        }
    }
    return 0
}
@@ -2,6 +2,8 @@ package hub

import (
    "context"
    "encoding/base64"
    "strconv"
    "sync"
    "time"

@@ -37,10 +39,11 @@ type Client struct {
    idsMu sync.Mutex
}
type Room struct {
    ID string
    clients map[*Client]bool
    mu sync.RWMutex
    cancel context.CancelFunc
    ID      string
    clients map[*Client]bool
    mu      sync.RWMutex
    cancel  context.CancelFunc
    reconnectCount int // Track Redis reconnection attempts for debugging
}

type Hub struct {
@@ -64,6 +67,10 @@ type Hub struct {

    // Bounded worker pool for Redis SetAwareness
    awarenessQueue chan awarenessItem

    // Stream persistence worker pool (P1: Redis Streams durability)
    streamQueue chan *Message // buffered queue for XADD operations
    streamDone  chan struct{} // close to signal stream workers to exit
}

const (
@@ -79,6 +86,13 @@ const (

    // awarenessQueueSize is the buffer size for awareness updates.
    awarenessQueueSize = 4096

    // streamWorkerCount is the number of fixed goroutines consuming from streamQueue.
    // 50 workers, matching the publish worker pool for consistent throughput.
    streamWorkerCount = 50

    // streamQueueSize is the buffer size for the stream persistence queue.
    streamQueueSize = 4096
)

type awarenessItem struct {
@@ -103,11 +117,15 @@ func NewHub(messagebus messagebus.MessageBus, serverID string, logger *zap.Logge
        publishDone: make(chan struct{}),
        // bounded awareness worker pool
        awarenessQueue: make(chan awarenessItem, awarenessQueueSize),
        // Stream persistence worker pool
        streamQueue: make(chan *Message, streamQueueSize),
        streamDone:  make(chan struct{}),
    }

    // Start the fixed worker pool for Redis publishing
    h.startPublishWorkers(publishWorkerCount)
    h.startAwarenessWorkers(awarenessWorkerCount)
    h.startStreamWorkers(streamWorkerCount)

    return h
}
@@ -173,6 +191,82 @@ func (h *Hub) startAwarenessWorkers(n int) {
    h.logger.Info("Awareness worker pool started", zap.Int("workers", n))
}

// startStreamWorkers launches n goroutines that consume from streamQueue
// and add messages to Redis Streams for durability and replay.
func (h *Hub) startStreamWorkers(n int) {
    for i := 0; i < n; i++ {
        go func(workerID int) {
            for {
                select {
                case <-h.streamDone:
                    h.logger.Info("Stream worker exiting", zap.Int("worker_id", workerID))
                    return
                case msg, ok := <-h.streamQueue:
                    if !ok {
                        return
                    }
                    h.addToStream(msg)
                }
            }
        }(i)
    }
    h.logger.Info("Stream worker pool started", zap.Int("workers", n))
}

// encodeBase64 encodes binary data to a base64 string for Redis storage
func encodeBase64(data []byte) string {
    return base64.StdEncoding.EncodeToString(data)
}

// addToStream adds a message to Redis Streams for durability
func (h *Hub) addToStream(msg *Message) {
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    defer cancel()

    streamKey := "stream:" + msg.RoomID

    // Get next sequence number atomically
    seqKey := "seq:" + msg.RoomID
    seq, err := h.messagebus.Incr(ctx, seqKey)
    if err != nil {
        h.logger.Error("Failed to increment sequence",
            zap.String("room_id", msg.RoomID),
            zap.Error(err))
        return
    }

    // Encode payload as base64 (binary-safe storage)
    payload := encodeBase64(msg.Data)

    // Extract Yjs message type from first byte as numeric string
    msgType := "0"
    if len(msg.Data) > 0 {
        msgType = strconv.Itoa(int(msg.Data[0]))
    }

    // Add entry to Stream with MAXLEN trimming
    values := map[string]interface{}{
        "type":        "update",
        "server_id":   h.serverID,
        "yjs_payload": payload,
        "msg_type":    msgType,
        "seq":         seq,
        "timestamp":   time.Now().Format(time.RFC3339),
    }

    _, err = h.messagebus.XAdd(ctx, streamKey, 10000, true, values)
    if err != nil {
        h.logger.Error("Failed to add to Stream",
            zap.String("stream_key", streamKey),
            zap.Int64("seq", seq),
            zap.Error(err))
        return
    }

    // Mark this document as active so the persist worker only processes active streams
    _ = h.messagebus.ZAdd(ctx, "active-streams", float64(time.Now().Unix()), msg.RoomID)
}

func (h *Hub) Run() {
    for {
        select {
@@ -471,6 +565,7 @@ func (h *Hub) broadcastMessage(message *Message) {
    // Only messages sent by local clients (sender != nil) are pushed to Redis
    // P0 fix: send to bounded worker pool instead of spawning unbounded goroutines
    if message.sender != nil && !h.fallbackMode && h.messagebus != nil {
        // 3a. Publish to Pub/Sub (real-time cross-server broadcast)
        select {
        case h.publishQueue <- message:
            // Successfully queued for async publish by worker pool
@@ -479,6 +574,19 @@ func (h *Hub) broadcastMessage(message *Message) {
            h.logger.Warn("Publish queue full, dropping Redis publish",
                zap.String("room_id", message.RoomID))
        }

        // 3b. Add to Stream for durability (only Type 0 updates, not Type 1 awareness)
        // Type 0 = Yjs sync/update messages (document changes)
        // Type 1 = Yjs awareness messages (cursors, presence) - ephemeral, skip
        if len(message.Data) > 0 && message.Data[0] == 0 {
            select {
            case h.streamQueue <- message:
                // Successfully queued for async Stream add
            default:
                h.logger.Warn("Stream queue full, dropping durability",
                    zap.String("room_id", message.RoomID))
            }
        }
    }
}

@@ -504,10 +612,28 @@ func (h *Hub) broadcastToLocalClients(room *Room, data []byte, sender *Client) {
    }
}
func (h *Hub) startRoomMessageForwarding(ctx context.Context, roomID string, msgChan <-chan []byte) {
    h.logger.Info("Starting message forwarding from Redis to room",
        zap.String("room_id", roomID),
        zap.String("server_id", h.serverID),
    )
    // Increment and log reconnection count for debugging
    h.mu.RLock()
    room, exists := h.rooms[roomID]
    h.mu.RUnlock()

    if exists {
        room.mu.Lock()
        room.reconnectCount++
        reconnectCount := room.reconnectCount
        room.mu.Unlock()

        h.logger.Info("Starting message forwarding from Redis to room",
            zap.String("room_id", roomID),
            zap.String("server_id", h.serverID),
            zap.Int("reconnect_count", reconnectCount),
        )
    } else {
        h.logger.Info("Starting message forwarding from Redis to room",
            zap.String("room_id", roomID),
            zap.String("server_id", h.serverID),
        )
    }

    for {
        select {
@@ -791,12 +917,28 @@ func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string
        UserAvatar: userAvatar,
        Permission: permission,
        Conn: conn,
        send: make(chan []byte, 1024),
        send: make(chan []byte, 8192),
        hub: hub,
        roomID: roomID,
        observedYjsIDs: make(map[uint64]uint64),
    }
}

// Enqueue sends a message to the client send buffer (non-blocking).
// Returns false if the buffer is full.
func (c *Client) Enqueue(message []byte) bool {
    select {
    case c.send <- message:
        return true
    default:
        if c.hub != nil && c.hub.logger != nil {
            c.hub.logger.Warn("Client send buffer full during replay",
                zap.String("client_id", c.ID),
                zap.String("room_id", c.roomID))
        }
        return false
    }
}
func (c *Client) unregister() {
    c.unregisterOnce.Do(func() {
        c.hub.Unregister <- c

@@ -23,7 +23,7 @@ func NewLogger(isDevelopment bool) (*zap.Logger, error) {
    // 👇 Key change: pin the level straight to Fatal
    // so that Error, Warn, Info, and Debug are all ignored,
    // completely eliminating IO lock contention
    config.Level = zap.NewAtomicLevelAt(zapcore.FatalLevel)
    config.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel)

    logger, err := config.Build()
    if err != nil {
@@ -2,6 +2,7 @@ package messagebus

import (
    "context"
    "time"
)

// MessageBus abstracts message distribution across server instances
@@ -33,6 +34,72 @@ type MessageBus interface {

    // Close gracefully shuts down the message bus
    Close() error

    // ========== Redis Streams Operations ==========

    // XAdd adds a new entry to a stream with optional MAXLEN trimming
    XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error)

    // XReadGroup reads messages from a stream using a consumer group
    XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error)

    // XAck acknowledges one or more messages from a consumer group
    XAck(ctx context.Context, stream, group string, ids ...string) (int64, error)

    // XGroupCreate creates a new consumer group for a stream
    XGroupCreate(ctx context.Context, stream, group, start string) error

    // XGroupCreateMkStream creates a consumer group and the stream if it doesn't exist
    XGroupCreateMkStream(ctx context.Context, stream, group, start string) error

    // XPending returns pending messages information for a consumer group
    XPending(ctx context.Context, stream, group string) (*PendingInfo, error)

    // XClaim claims pending messages from a consumer group
    XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error)

    // XAutoClaim claims pending messages automatically (Redis >= 6.2)
    // Returns claimed messages and next start ID.
    XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error)

    // XRange reads a range of messages from a stream
    XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error)

    // XTrimMinID trims a stream to a minimum ID (time-based retention)
    XTrimMinID(ctx context.Context, stream, minID string) (int64, error)

    // Incr increments a counter atomically (for sequence numbers)
    Incr(ctx context.Context, key string) (int64, error)

    // ========== Sorted Set (ZSET) Operations ==========

    // ZAdd adds a member with a score to a sorted set (used for active-stream tracking)
    ZAdd(ctx context.Context, key string, score float64, member string) error

    // ZRangeByScore returns members with scores between min and max
    ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error)

    // ZRemRangeByScore removes members with scores between min and max
    ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error)

    // Distributed lock helpers (used by background workers)
    AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error)
    RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error)
    ReleaseLock(ctx context.Context, key string) error
}

// StreamMessage represents a message from a Redis Stream
type StreamMessage struct {
    ID     string
    Values map[string]interface{}
}

// PendingInfo contains information about pending messages in a consumer group
type PendingInfo struct {
    Count     int64
    Lower     string
    Upper     string
    Consumers map[string]int64
}
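For orientation, a minimal consumer against this interface (a sketch in package messagebus scope; the group and consumer names are illustrative, and real callers — like the update persist worker below — also tolerate Redis's BUSYGROUP reply when the group already exists):

// Sketch only: drain one batch from a document's stream. The
// "stream:"+docID key convention matches the hub code in this commit.
func drainOnce(ctx context.Context, bus MessageBus, docID, group, consumer string) error {
    stream := "stream:" + docID

    // Creates the group (and the stream, if missing); a BUSYGROUP error
    // here just means the group already exists.
    if err := bus.XGroupCreateMkStream(ctx, stream, group, "0"); err != nil {
        return err
    }

    // A negative block keeps the read non-blocking (no BLOCK clause).
    msgs, err := bus.XReadGroup(ctx, group, consumer, []string{stream, ">"}, 100, -1)
    if err != nil || len(msgs) == 0 {
        return err
    }

    ids := make([]string, 0, len(msgs))
    for _, m := range msgs {
        // Handle m.Values["yjs_payload"], m.Values["seq"], etc. here.
        ids = append(ids, m.ID)
    }
    _, err = bus.XAck(ctx, stream, group, ids...)
    return err
}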
// LocalMessageBus is a no-op implementation for single-server mode
@@ -78,3 +145,73 @@ func (l *LocalMessageBus) IsHealthy() bool {
func (l *LocalMessageBus) Close() error {
    return nil
}

// ========== Redis Streams Operations (No-op for local mode) ==========

func (l *LocalMessageBus) XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) {
    return "0-0", nil
}

func (l *LocalMessageBus) XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) {
    return nil, nil
}

func (l *LocalMessageBus) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) {
    return 0, nil
}

func (l *LocalMessageBus) XGroupCreate(ctx context.Context, stream, group, start string) error {
    return nil
}

func (l *LocalMessageBus) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error {
    return nil
}

func (l *LocalMessageBus) XPending(ctx context.Context, stream, group string) (*PendingInfo, error) {
    return &PendingInfo{}, nil
}

func (l *LocalMessageBus) XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) {
    return nil, nil
}

func (l *LocalMessageBus) XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) {
    return nil, "0-0", nil
}

func (l *LocalMessageBus) XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) {
    return nil, nil
}

func (l *LocalMessageBus) XTrimMinID(ctx context.Context, stream, minID string) (int64, error) {
    return 0, nil
}

func (l *LocalMessageBus) Incr(ctx context.Context, key string) (int64, error) {
    return 0, nil
}

func (l *LocalMessageBus) ZAdd(ctx context.Context, key string, score float64, member string) error {
    return nil
}

func (l *LocalMessageBus) ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) {
    return nil, nil
}

func (l *LocalMessageBus) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) {
    return 0, nil
}

func (l *LocalMessageBus) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
    return true, nil
}

func (l *LocalMessageBus) RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
    return true, nil
}

func (l *LocalMessageBus) ReleaseLock(ctx context.Context, key string) error {
    return nil
}
@@ -7,6 +7,7 @@ import (
    "fmt"
    "io"
    "log"
    "net"
    "strconv"
    "sync"
    "time"
@@ -88,6 +89,23 @@ func NewRedisMessageBus(redisURL string, serverID string, logger *zap.Logger) (*
    // - Redis will handle stale connections via TCP keepalive
    opts.ConnMaxLifetime = 1 * time.Hour

    // ================================
    // Socket-Level Timeout Configuration (prevents indefinite hangs)
    // ================================
    // Without these, TCP reads/writes block indefinitely when Redis is unresponsive,
    // causing OS-level timeouts (60-120s) instead of application-level control.

    // DialTimeout: How long to wait for initial connection establishment
    opts.DialTimeout = 5 * time.Second

    // ReadTimeout: Maximum time for socket read operations
    // - 30s is appropriate for PubSub (long intervals between messages are normal)
    // - Prevents indefinite blocking when Redis hangs
    opts.ReadTimeout = 30 * time.Second

    // WriteTimeout: Maximum time for socket write operations
    opts.WriteTimeout = 5 * time.Second

    client := goredis.NewClient(opts)

    // ================================
@@ -215,12 +233,15 @@ func (r *RedisMessageBus) readLoop(ctx context.Context, roomID string, sub *subs
    if ctx.Err() != nil {
        return
    }
    r.logger.Warn("PubSub initial subscription failed, retrying with backoff",
        zap.String("roomID", roomID),
        zap.Error(err),
        zap.Duration("backoff", backoff),
    )
    time.Sleep(backoff)
    if backoff < maxBackoff {
        backoff *= 2
        if backoff > maxBackoff {
            backoff = maxBackoff
        }
    backoff = backoff * 2
    if backoff > maxBackoff {
        backoff = maxBackoff
    }
    continue
}
@@ -242,12 +263,15 @@ func (r *RedisMessageBus) readLoop(ctx context.Context, roomID string, sub *subs
    if ctx.Err() != nil {
        return
    }
    r.logger.Warn("PubSub receive failed, retrying with backoff",
        zap.String("roomID", roomID),
        zap.Error(err),
        zap.Duration("backoff", backoff),
    )
    time.Sleep(backoff)
    if backoff < maxBackoff {
        backoff *= 2
        if backoff > maxBackoff {
            backoff = maxBackoff
        }
    backoff = backoff * 2
    if backoff > maxBackoff {
        backoff = maxBackoff
    }
}
}
@@ -261,12 +285,15 @@ func (r *RedisMessageBus) receiveOnce(ctx context.Context, roomID string, pubsub

    msg, err := pubsub.ReceiveTimeout(ctx, 5*time.Second)
    if err != nil {
        if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
            return err
        if ctx.Err() != nil {
            return ctx.Err()
        }
        if errors.Is(err, goredis.Nil) {
            continue
        }
        if isTimeoutErr(err) {
            continue
        }
        r.logger.Warn("pubsub receive error, closing subscription",
            zap.String("roomID", roomID),
            zap.Error(err),
@@ -308,6 +335,17 @@ func (r *RedisMessageBus) receiveOnce(ctx context.Context, roomID string, pubsub
    }
}

func isTimeoutErr(err error) bool {
    if err == nil {
        return false
    }
    if errors.Is(err, context.DeadlineExceeded) {
        return true
    }
    var netErr net.Error
    return errors.As(err, &netErr) && netErr.Timeout()
}

// Unsubscribe stops listening to a room
func (r *RedisMessageBus) Unsubscribe(ctx context.Context, roomID string) error {
    r.subMu.Lock()
@@ -430,7 +468,7 @@ func (r *RedisMessageBus) DeleteAwareness(ctx context.Context, roomID string, cl

// IsHealthy checks Redis connectivity
func (r *RedisMessageBus) IsHealthy() bool {
    ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    // The service is considered healthy only if Ping succeeds without error
@@ -516,3 +554,223 @@ func (r *RedisMessageBus) ClearAllAwareness(ctx context.Context, roomID string)
    // Simply delete the whole key with the Del command
    return r.client.Del(ctx, key).Err()
}

// ========== Redis Streams Operations ==========

// XAdd adds a new entry to a stream with optional MAXLEN trimming
func (r *RedisMessageBus) XAdd(ctx context.Context, stream string, maxLen int64, approx bool, values map[string]interface{}) (string, error) {
    result := r.client.XAdd(ctx, &goredis.XAddArgs{
        Stream: stream,
        MaxLen: maxLen,
        Approx: approx,
        Values: values,
    })
    return result.Val(), result.Err()
}

// XReadGroup reads messages from a stream using a consumer group
func (r *RedisMessageBus) XReadGroup(ctx context.Context, group, consumer string, streams []string, count int64, block time.Duration) ([]StreamMessage, error) {
    result := r.client.XReadGroup(ctx, &goredis.XReadGroupArgs{
        Group:    group,
        Consumer: consumer,
        Streams:  streams,
        Count:    count,
        Block:    block,
    })

    if err := result.Err(); err != nil {
        // Timeout is not an error, just no new messages
        if err == goredis.Nil {
            return nil, nil
        }
        return nil, err
    }

    // Convert go-redis XStream to our StreamMessage format
    var messages []StreamMessage
    for _, stream := range result.Val() {
        for _, msg := range stream.Messages {
            messages = append(messages, StreamMessage{
                ID:     msg.ID,
                Values: msg.Values,
            })
        }
    }

    return messages, nil
}

// XAck acknowledges one or more messages from a consumer group
func (r *RedisMessageBus) XAck(ctx context.Context, stream, group string, ids ...string) (int64, error) {
    result := r.client.XAck(ctx, stream, group, ids...)
    return result.Val(), result.Err()
}

// XGroupCreate creates a new consumer group for a stream
func (r *RedisMessageBus) XGroupCreate(ctx context.Context, stream, group, start string) error {
    return r.client.XGroupCreate(ctx, stream, group, start).Err()
}

// XGroupCreateMkStream creates a consumer group and the stream if it doesn't exist
func (r *RedisMessageBus) XGroupCreateMkStream(ctx context.Context, stream, group, start string) error {
    return r.client.XGroupCreateMkStream(ctx, stream, group, start).Err()
}

// XPending returns pending messages information for a consumer group
func (r *RedisMessageBus) XPending(ctx context.Context, stream, group string) (*PendingInfo, error) {
    result := r.client.XPending(ctx, stream, group)
    if err := result.Err(); err != nil {
        return nil, err
    }

    pending := result.Val()
    consumers := make(map[string]int64)
    for name, count := range pending.Consumers {
        consumers[name] = count
    }

    return &PendingInfo{
        Count:     pending.Count,
        Lower:     pending.Lower,
        Upper:     pending.Higher, // go-redis uses "Higher" instead of "Upper"
        Consumers: consumers,
    }, nil
}

// XClaim claims pending messages from a consumer group
func (r *RedisMessageBus) XClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, ids ...string) ([]StreamMessage, error) {
    result := r.client.XClaim(ctx, &goredis.XClaimArgs{
        Stream:   stream,
        Group:    group,
        Consumer: consumer,
        MinIdle:  minIdleTime,
        Messages: ids,
    })

    if err := result.Err(); err != nil {
        return nil, err
    }

    // Convert go-redis XMessage to our StreamMessage format
    var messages []StreamMessage
    for _, msg := range result.Val() {
        messages = append(messages, StreamMessage{
            ID:     msg.ID,
            Values: msg.Values,
        })
    }

    return messages, nil
}

// XAutoClaim claims pending messages automatically (Redis >= 6.2)
func (r *RedisMessageBus) XAutoClaim(ctx context.Context, stream, group, consumer string, minIdleTime time.Duration, start string, count int64) ([]StreamMessage, string, error) {
    result := r.client.XAutoClaim(ctx, &goredis.XAutoClaimArgs{
        Stream:   stream,
        Group:    group,
        Consumer: consumer,
        MinIdle:  minIdleTime,
        Start:    start,
        Count:    count,
    })
    msgs, nextStart, err := result.Result()
    if err != nil {
        return nil, "", err
    }
    messages := make([]StreamMessage, 0, len(msgs))
    for _, msg := range msgs {
        messages = append(messages, StreamMessage{
            ID:     msg.ID,
            Values: msg.Values,
        })
    }
    return messages, nextStart, nil
}

// XRange reads a range of messages from a stream
func (r *RedisMessageBus) XRange(ctx context.Context, stream, start, end string) ([]StreamMessage, error) {
    result := r.client.XRange(ctx, stream, start, end)
    if err := result.Err(); err != nil {
        return nil, err
    }

    // Convert go-redis XMessage to our StreamMessage format
    var messages []StreamMessage
    for _, msg := range result.Val() {
        messages = append(messages, StreamMessage{
            ID:     msg.ID,
            Values: msg.Values,
        })
    }

    return messages, nil
}

// XTrimMinID trims a stream to a minimum ID (time-based retention)
func (r *RedisMessageBus) XTrimMinID(ctx context.Context, stream, minID string) (int64, error) {
    // Use XTRIM with MINID and approximation (~) for efficiency
    // LIMIT clause prevents blocking Redis during large trims
    result := r.client.Do(ctx, "XTRIM", stream, "MINID", "~", minID, "LIMIT", 1000)
    if err := result.Err(); err != nil {
        return 0, err
    }

    // Result is the number of entries removed
    trimmed, err := result.Int64()
    if err != nil {
        return 0, err
    }

    return trimmed, nil
}

// ========== Sorted Set (ZSET) Operations ==========

// ZAdd adds a member with a score to a sorted set
func (r *RedisMessageBus) ZAdd(ctx context.Context, key string, score float64, member string) error {
    return r.client.ZAdd(ctx, key, goredis.Z{Score: score, Member: member}).Err()
}

// ZRangeByScore returns members with scores between min and max
func (r *RedisMessageBus) ZRangeByScore(ctx context.Context, key string, min, max float64) ([]string, error) {
    return r.client.ZRangeByScore(ctx, key, &goredis.ZRangeBy{
        Min: strconv.FormatFloat(min, 'f', -1, 64),
        Max: strconv.FormatFloat(max, 'f', -1, 64),
    }).Result()
}

// ZRemRangeByScore removes members with scores between min and max
func (r *RedisMessageBus) ZRemRangeByScore(ctx context.Context, key string, min, max float64) (int64, error) {
    return r.client.ZRemRangeByScore(ctx, key,
        strconv.FormatFloat(min, 'f', -1, 64),
        strconv.FormatFloat(max, 'f', -1, 64),
    ).Result()
}

// Incr increments a counter atomically (for sequence numbers)
func (r *RedisMessageBus) Incr(ctx context.Context, key string) (int64, error) {
    result := r.client.Incr(ctx, key)
    return result.Val(), result.Err()
}

// AcquireLock attempts to acquire a distributed lock with TTL
func (r *RedisMessageBus) AcquireLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
    return r.client.SetNX(ctx, key, r.serverID, ttl).Result()
}

// RefreshLock extends the TTL on an existing lock
func (r *RedisMessageBus) RefreshLock(ctx context.Context, key string, ttl time.Duration) (bool, error) {
    result := r.client.SetArgs(ctx, key, r.serverID, goredis.SetArgs{
        Mode: "XX",
        TTL:  ttl,
    })
    if err := result.Err(); err != nil {
        return false, err
    }
    return result.Val() == "OK", nil
}

// ReleaseLock releases a distributed lock
func (r *RedisMessageBus) ReleaseLock(ctx context.Context, key string) error {
    return r.client.Del(ctx, key).Err()
}
backend/internal/models/stream_checkpoint.go (new file)
@@ -0,0 +1,15 @@
package models

import (
    "time"

    "github.com/google/uuid"
)

// StreamCheckpoint tracks the last processed Redis Stream entry per document
type StreamCheckpoint struct {
    DocumentID   uuid.UUID `json:"document_id"`
    LastStreamID string    `json:"last_stream_id"`
    LastSeq      int64     `json:"last_seq"`
    UpdatedAt    time.Time `json:"updated_at"`
}
@@ -53,6 +53,15 @@ type Store interface {
    GetDocumentVersion(ctx context.Context, versionID uuid.UUID) (*models.DocumentVersion, error)
    GetLatestDocumentVersion(ctx context.Context, documentID uuid.UUID) (*models.DocumentVersion, error)

    // Stream checkpoint operations
    UpsertStreamCheckpoint(ctx context.Context, documentID uuid.UUID, streamID string, seq int64) error
    GetStreamCheckpoint(ctx context.Context, documentID uuid.UUID) (*models.StreamCheckpoint, error)

    // Update history (WAL) operations
    InsertUpdateHistoryBatch(ctx context.Context, entries []UpdateHistoryEntry) error
    ListUpdateHistoryAfterSeq(ctx context.Context, documentID uuid.UUID, afterSeq int64, limit int) ([]UpdateHistoryEntry, error)
    DeleteUpdateHistoryUpToSeq(ctx context.Context, documentID uuid.UUID, maxSeq int64) error

    Close() error
}
backend/internal/store/stream_checkpoint.go (new file)
@@ -0,0 +1,46 @@
package store

import (
    "context"
    "fmt"

    "github.com/M1ngdaXie/realtime-collab/internal/models"
    "github.com/google/uuid"
)

// UpsertStreamCheckpoint creates or updates the stream checkpoint for a document
func (s *PostgresStore) UpsertStreamCheckpoint(ctx context.Context, documentID uuid.UUID, streamID string, seq int64) error {
    query := `
        INSERT INTO stream_checkpoints (document_id, last_stream_id, last_seq, updated_at)
        VALUES ($1, $2, $3, NOW())
        ON CONFLICT (document_id)
        DO UPDATE SET last_stream_id = EXCLUDED.last_stream_id,
                      last_seq = EXCLUDED.last_seq,
                      updated_at = NOW()
    `

    if _, err := s.db.ExecContext(ctx, query, documentID, streamID, seq); err != nil {
        return fmt.Errorf("failed to upsert stream checkpoint: %w", err)
    }
    return nil
}

// GetStreamCheckpoint retrieves the stream checkpoint for a document
func (s *PostgresStore) GetStreamCheckpoint(ctx context.Context, documentID uuid.UUID) (*models.StreamCheckpoint, error) {
    query := `
        SELECT document_id, last_stream_id, last_seq, updated_at
        FROM stream_checkpoints
        WHERE document_id = $1
    `

    var checkpoint models.StreamCheckpoint
    if err := s.db.QueryRowContext(ctx, query, documentID).Scan(
        &checkpoint.DocumentID,
        &checkpoint.LastStreamID,
        &checkpoint.LastSeq,
        &checkpoint.UpdatedAt,
    ); err != nil {
        return nil, fmt.Errorf("failed to get stream checkpoint: %w", err)
    }
    return &checkpoint, nil
}
@@ -71,10 +71,14 @@ func SetupTestDB(t *testing.T) (*PostgresStore, func()) {
    // Run migrations
    scriptsDir := filepath.Join("..", "..", "scripts")
    migrations := []string{
        "init.sql",
        "001_add_users_and_sessions.sql",
        "002_add_document_shares.sql",
        "003_add_public_sharing.sql",
        "000_extensions.sql",
        "001_init_schema.sql",
        "002_add_users_and_sessions.sql",
        "003_add_document_shares.sql",
        "004_add_public_sharing.sql",
        "005_add_share_link_permission.sql",
        "010_add_stream_checkpoints.sql",
        "011_add_update_history.sql",
    }

    for _, migration := range migrations {
@@ -107,6 +111,8 @@ func SetupTestDB(t *testing.T) (*PostgresStore, func()) {
func TruncateAllTables(ctx context.Context, store *PostgresStore) error {
    tables := []string{
        "document_updates",
        "document_update_history",
        "stream_checkpoints",
        "document_shares",
        "sessions",
        "documents",
backend/internal/store/update_history.go (new file)
@@ -0,0 +1,115 @@
package store

import (
    "context"
    "fmt"
    "strings"
    "time"
    "unicode/utf8"

    "github.com/google/uuid"
)

// UpdateHistoryEntry represents a persisted update from Redis Streams
// used for recovery and replay.
type UpdateHistoryEntry struct {
    DocumentID uuid.UUID
    StreamID   string
    Seq        int64
    Payload    []byte
    MsgType    string
    ServerID   string
    CreatedAt  time.Time
}

// InsertUpdateHistoryBatch inserts update history entries in a single batch.
// Uses ON CONFLICT DO NOTHING to make inserts idempotent.
func (s *PostgresStore) InsertUpdateHistoryBatch(ctx context.Context, entries []UpdateHistoryEntry) error {
    if len(entries) == 0 {
        return nil
    }

    var sb strings.Builder
    sb.WriteString("INSERT INTO document_update_history (document_id, stream_id, seq, payload, msg_type, server_id, created_at) VALUES ")

    args := make([]interface{}, 0, len(entries)*7)
    for i, e := range entries {
        if i > 0 {
            sb.WriteString(",")
        }
        base := i*7 + 1
        sb.WriteString(fmt.Sprintf("($%d,$%d,$%d,$%d,$%d,$%d,$%d)", base, base+1, base+2, base+3, base+4, base+5, base+6))
        msgType := sanitizeTextForDB(e.MsgType)
        serverID := sanitizeTextForDB(e.ServerID)
        args = append(args, e.DocumentID, e.StreamID, e.Seq, e.Payload, nullIfEmpty(msgType), nullIfEmpty(serverID), e.CreatedAt)
    }
    // Idempotent insert
    sb.WriteString(" ON CONFLICT (document_id, stream_id) DO NOTHING")

    if _, err := s.db.ExecContext(ctx, sb.String(), args...); err != nil {
        return fmt.Errorf("failed to insert update history batch: %w", err)
    }
    return nil
}

// ListUpdateHistoryAfterSeq returns updates with seq greater than afterSeq, ordered by seq.
func (s *PostgresStore) ListUpdateHistoryAfterSeq(ctx context.Context, documentID uuid.UUID, afterSeq int64, limit int) ([]UpdateHistoryEntry, error) {
    if limit <= 0 {
        limit = 1000
    }
    query := `
        SELECT document_id, stream_id, seq, payload, COALESCE(msg_type, ''), COALESCE(server_id, ''), created_at
        FROM document_update_history
        WHERE document_id = $1 AND seq > $2
        ORDER BY seq ASC
        LIMIT $3
    `

    rows, err := s.db.QueryContext(ctx, query, documentID, afterSeq, limit)
    if err != nil {
        return nil, fmt.Errorf("failed to list update history: %w", err)
    }
    defer rows.Close()

    var results []UpdateHistoryEntry
    for rows.Next() {
        var e UpdateHistoryEntry
        if err := rows.Scan(&e.DocumentID, &e.StreamID, &e.Seq, &e.Payload, &e.MsgType, &e.ServerID, &e.CreatedAt); err != nil {
            return nil, fmt.Errorf("failed to scan update history: %w", err)
        }
        results = append(results, e)
    }
    return results, nil
}

// DeleteUpdateHistoryUpToSeq deletes updates with seq <= maxSeq for a document.
func (s *PostgresStore) DeleteUpdateHistoryUpToSeq(ctx context.Context, documentID uuid.UUID, maxSeq int64) error {
    query := `
        DELETE FROM document_update_history
        WHERE document_id = $1 AND seq <= $2
    `
    if _, err := s.db.ExecContext(ctx, query, documentID, maxSeq); err != nil {
        return fmt.Errorf("failed to delete update history: %w", err)
    }
    return nil
}

func nullIfEmpty(s string) interface{} {
    if s == "" {
        return nil
    }
    return s
}

func sanitizeTextForDB(s string) string {
    if s == "" {
        return ""
    }
    if strings.IndexByte(s, 0) >= 0 {
        return ""
    }
    if !utf8.ValidString(s) {
        return ""
    }
    return s
}
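The referenced migrations (scripts/010_add_stream_checkpoints.sql and scripts/011_add_update_history.sql) are not shown in this diff. A plausible reconstruction from the queries above — the column names and the (document_id, stream_id) conflict target are grounded in the code, while the types, defaults, and index are assumptions:

// Hypothetical DDL inferred from the store queries; the real migration
// files are not part of this commit's visible hunks.
const addStreamCheckpointsSQL = `
CREATE TABLE IF NOT EXISTS stream_checkpoints (
    document_id    UUID PRIMARY KEY,
    last_stream_id TEXT NOT NULL,
    last_seq       BIGINT NOT NULL,
    updated_at     TIMESTAMPTZ NOT NULL DEFAULT NOW()
);`

const addUpdateHistorySQL = `
CREATE TABLE IF NOT EXISTS document_update_history (
    document_id UUID   NOT NULL,
    stream_id   TEXT   NOT NULL,
    seq         BIGINT NOT NULL,
    payload     BYTEA  NOT NULL,
    msg_type    TEXT,
    server_id   TEXT,
    created_at  TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    PRIMARY KEY (document_id, stream_id)  -- backs ON CONFLICT DO NOTHING
);
-- Assumed index so ListUpdateHistoryAfterSeq's seq-range scan stays cheap.
CREATE INDEX IF NOT EXISTS idx_update_history_doc_seq
    ON document_update_history (document_id, seq);`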
backend/internal/workers/update_persist_worker.go (new file)
@@ -0,0 +1,320 @@
|
||||
package workers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
updatePersistGroupName = "update-persist-worker"
|
||||
updatePersistLockKey = "lock:update-persist-worker"
|
||||
updatePersistLockTTL = 30 * time.Second
|
||||
updatePersistTick = 2 * time.Second
|
||||
updateReadCount = 200
|
||||
updateReadBlock = -1 // negative → go-redis omits BLOCK clause → non-blocking
|
||||
updateBatchSize = 500
|
||||
updateSafeSeqLag = int64(1000)
|
||||
updateAutoClaimIdle = 30 * time.Second
|
||||
updateHeartbeatEvery = 30 * time.Second
|
||||
)
|
||||
|
||||
// StartUpdatePersistWorker persists Redis Stream updates into Postgres for recovery.
|
||||
func StartUpdatePersistWorker(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) {
|
||||
if msgBus == nil || dbStore == nil {
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
logWorker(logger, "Update persist worker panic",
|
||||
zap.Any("panic", r),
|
||||
zap.ByteString("stack", debug.Stack()))
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
acquired, err := msgBus.AcquireLock(ctx, updatePersistLockKey, updatePersistLockTTL)
|
||||
if err != nil {
|
||||
logWorker(logger, "Failed to acquire update persist worker lock", zap.Error(err))
|
||||
time.Sleep(updatePersistTick)
|
||||
return
|
||||
}
|
||||
if !acquired {
|
||||
time.Sleep(updatePersistTick)
|
||||
return
|
||||
}
|
||||
|
||||
logWorker(logger, "Update persist worker lock acquired", zap.String("server_id", serverID))
|
||||
runUpdatePersistWorker(ctx, msgBus, dbStore, logger, serverID)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
// If the worker exited (including panic), pause briefly before retry.
|
||||
time.Sleep(updatePersistTick)
|
||||
}
|
||||
}
|
||||
|
||||
func runUpdatePersistWorker(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) {
|
||||
ticker := time.NewTicker(updatePersistTick)
|
||||
defer ticker.Stop()
|
||||
|
||||
refreshTicker := time.NewTicker(updatePersistLockTTL / 2)
|
||||
defer refreshTicker.Stop()
|
||||
|
||||
heartbeatTicker := time.NewTicker(updateHeartbeatEvery)
|
||||
defer heartbeatTicker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
_ = msgBus.ReleaseLock(ctx, updatePersistLockKey)
|
||||
return
|
||||
case <-refreshTicker.C:
|
||||
ok, err := msgBus.RefreshLock(ctx, updatePersistLockKey, updatePersistLockTTL)
|
||||
if err != nil || !ok {
|
||||
logWorker(logger, "Update persist worker lock lost", zap.Error(err))
|
||||
_ = msgBus.ReleaseLock(ctx, updatePersistLockKey)
|
||||
return
|
||||
}
|
||||
case <-heartbeatTicker.C:
|
||||
logWorker(logger, "Update persist worker heartbeat", zap.String("server_id", serverID))
|
||||
case <-ticker.C:
|
||||
if err := processUpdatePersistence(ctx, msgBus, dbStore, logger, serverID); err != nil {
|
||||
logWorker(logger, "Update persist worker tick failed", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func processUpdatePersistence(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, serverID string) error {
	// Only process documents with recent stream activity (active in the last 60 seconds).
	cutoff := float64(time.Now().Add(-60 * time.Second).Unix())
	activeDocIDs, err := msgBus.ZRangeByScore(ctx, "active-streams", cutoff, float64(time.Now().Unix()))
	if err != nil {
		return fmt.Errorf("failed to get active streams: %w", err)
	}

	// Prune stale entries older than 5 minutes (best-effort cleanup).
	stale := float64(time.Now().Add(-5 * time.Minute).Unix())
	if _, err := msgBus.ZRemRangeByScore(ctx, "active-streams", 0, stale); err != nil {
		logWorker(logger, "Failed to prune stale active-streams entries", zap.Error(err))
	}

	for _, docIDStr := range activeDocIDs {
		docID, err := uuid.Parse(docIDStr)
		if err != nil {
			logWorker(logger, "Invalid document ID in active-streams", zap.String("doc_id", docIDStr))
			continue
		}

		streamKey := "stream:" + docIDStr
		if err := ensureConsumerGroup(ctx, msgBus, streamKey, updatePersistGroupName); err != nil {
			logWorker(logger, "Failed to ensure update persist consumer group", zap.String("stream", streamKey), zap.Error(err))
			continue
		}

		var ackIDs []string
		docEntries := make([]store.UpdateHistoryEntry, 0, updateBatchSize)

		// First, try to claim idle pending messages (e.g., left over from previous crashes).
		claimed, _, err := msgBus.XAutoClaim(ctx, streamKey, updatePersistGroupName, serverID, updateAutoClaimIdle, "0-0", updateReadCount)
		if err != nil {
			logWorker(logger, "XAutoClaim failed", zap.String("stream", streamKey), zap.Error(err))
		} else if len(claimed) > 0 {
			collectStreamMessages(ctx, msgBus, dbStore, logger, docID, streamKey, claimed, &docEntries, &ackIDs)
		}

		messages, err := msgBus.XReadGroup(ctx, updatePersistGroupName, serverID, []string{streamKey, ">"}, updateReadCount, updateReadBlock)
		if err != nil {
			logWorker(logger, "XReadGroup failed", zap.String("stream", streamKey), zap.Error(err))
			continue
		}
		if len(messages) > 0 {
			collectStreamMessages(ctx, msgBus, dbStore, logger, docID, streamKey, messages, &docEntries, &ackIDs)
		}

		if len(docEntries) > 0 {
			if err := dbStore.InsertUpdateHistoryBatch(ctx, docEntries); err != nil {
				logWorker(logger, "Failed to insert update history batch", zap.Error(err))
				// Skip the ACK so the batch is retried on the next tick.
				continue
			}
		}

		if len(ackIDs) > 0 {
			if _, err := msgBus.XAck(ctx, streamKey, updatePersistGroupName, ackIDs...); err != nil {
				logWorker(logger, "XAck failed", zap.String("stream", streamKey), zap.Error(err))
			}
		}
	}

	return nil
}

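// Producer-side contract (a sketch, assuming the publisher mirrors this shape;
// the actual XAdd call lives in the message bus / hub code, and the ZAdd/XAdd
// signatures below are assumptions): each update is XADDed to the per-document
// stream and the document is scored into "active-streams" so the
// ZRangeByScore scan above picks it up.
//
//	now := float64(time.Now().Unix())
//	_ = msgBus.ZAdd(ctx, "active-streams", now, docID.String())
//	_, _ = msgBus.XAdd(ctx, "stream:"+docID.String(), map[string]interface{}{
//		"type":        "update",
//		"yjs_payload": base64.StdEncoding.EncodeToString(update),
//		"seq":         seq,
//		"msg_type":    msgType,
//		"server_id":   serverID,
//	})
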
func collectStreamMessages(ctx context.Context, msgBus messagebus.MessageBus, dbStore *store.PostgresStore, logger *zap.Logger, documentID uuid.UUID, streamKey string, messages []messagebus.StreamMessage, docEntries *[]store.UpdateHistoryEntry, ackIDs *[]string) {
	for _, msg := range messages {
		switch getString(msg.Values["type"]) {
		case "update":
			payloadB64 := getString(msg.Values["yjs_payload"])
			payload, err := base64.StdEncoding.DecodeString(payloadB64)
			if err != nil {
				logWorker(logger, "Failed to decode update payload",
					zap.String("stream", streamKey),
					zap.String("stream_id", msg.ID),
					zap.Error(err))
				// Ack the poison message anyway: it will never decode, and
				// leaving it pending would make XAutoClaim re-deliver it forever.
				*ackIDs = append(*ackIDs, msg.ID)
				continue
			}
			entry := store.UpdateHistoryEntry{
				DocumentID: documentID,
				StreamID:   msg.ID,
				Seq:        parseInt64(msg.Values["seq"]),
				Payload:    payload,
				MsgType:    normalizeMsgType(msg.Values["msg_type"]),
				ServerID:   sanitizeText(getString(msg.Values["server_id"])),
				CreatedAt:  time.Now().UTC(),
			}
			*docEntries = append(*docEntries, entry)
		case "snapshot":
			seq := parseInt64(msg.Values["seq"])
			if seq > 0 {
				if err := dbStore.UpsertStreamCheckpoint(ctx, documentID, msg.ID, seq); err != nil {
					logWorker(logger, "Failed to upsert stream checkpoint from snapshot marker",
						zap.String("document_id", documentID.String()),
						zap.Error(err))
				}
				// Retention: prune DB history up to the checkpoint, keeping a safety lag (best-effort).
				maxSeq := seq - updateSafeSeqLag
				if maxSeq > 0 {
					if err := dbStore.DeleteUpdateHistoryUpToSeq(ctx, documentID, maxSeq); err != nil {
						logWorker(logger, "Failed to prune update history",
							zap.String("document_id", documentID.String()),
							zap.Error(err))
					}
				}
				// Trim the Redis stream to avoid unbounded growth (best-effort).
				if _, err := msgBus.XTrimMinID(ctx, streamKey, msg.ID); err != nil {
					logWorker(logger, "Failed to trim Redis stream",
						zap.String("stream", streamKey),
						zap.Error(err))
				}
			}
		}
		*ackIDs = append(*ackIDs, msg.ID)
	}
}

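// Recovery sketch (illustrative, not part of this commit): on cold start a
// server could rebuild a document from the checkpoint plus persisted history.
// GetUpdateHistorySinceSeq, LoadSnapshot, and ApplyUpdate are hypothetical
// names for what the store and document layers would expose; the checkpoint
// fields are assumed from the migration schema below.
//
//	cp, _ := dbStore.GetStreamCheckpoint(ctx, docID)
//	entries, _ := dbStore.GetUpdateHistorySinceSeq(ctx, docID, cp.LastSeq)
//	doc := LoadSnapshot(ctx, docID)
//	for _, e := range entries {
//		ApplyUpdate(doc, e.Payload) // replay in seq order
//	}
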
func ensureConsumerGroup(ctx context.Context, msgBus messagebus.MessageBus, streamKey, group string) error {
	if err := msgBus.XGroupCreateMkStream(ctx, streamKey, group, "0-0"); err != nil {
		// BUSYGROUP means the group already exists, which is fine.
		if !isBusyGroup(err) {
			return err
		}
	}
	return nil
}

func isBusyGroup(err error) bool {
	if err == nil {
		return false
	}
	return strings.Contains(err.Error(), "BUSYGROUP")
}

// getString coerces a Redis stream field value to a string.
func getString(value interface{}) string {
	switch v := value.(type) {
	case string:
		return v
	case []byte:
		return string(v)
	default:
		return fmt.Sprint(v)
	}
}

// parseInt64 coerces a Redis stream field value to int64, returning 0 on failure.
func parseInt64(value interface{}) int64 {
	switch v := value.(type) {
	case int64:
		return v
	case int:
		return int64(v)
	case uint64:
		return int64(v)
	case string:
		if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
			return parsed
		}
	case []byte:
		if parsed, err := strconv.ParseInt(string(v), 10, 64); err == nil {
			return parsed
		}
	}
	return 0
}

// sanitizeText rejects strings that Postgres TEXT columns cannot store
// (embedded NUL bytes or invalid UTF-8).
func sanitizeText(s string) string {
	if s == "" {
		return s
	}
	if strings.IndexByte(s, 0) >= 0 {
		return ""
	}
	if !utf8.ValidString(s) {
		return ""
	}
	return s
}

// normalizeMsgType renders the message-type field as text. Single-byte values
// (likely the raw binary message-type byte) become their decimal representation.
func normalizeMsgType(value interface{}) string {
	switch v := value.(type) {
	case string:
		if v == "" {
			return ""
		}
		if len(v) == 1 {
			return strconv.Itoa(int(v[0]))
		}
		return sanitizeText(v)
	case []byte:
		if len(v) == 0 {
			return ""
		}
		if len(v) == 1 {
			return strconv.Itoa(int(v[0]))
		}
		return sanitizeText(string(v))
	default:
		return sanitizeText(fmt.Sprint(v))
	}
}

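// Behavior examples for normalizeMsgType (illustrative):
//
//	normalizeMsgType("\x00")      == "0"         // raw byte 0x00 → "0"
//	normalizeMsgType([]byte{2})   == "2"         // raw byte 0x02 → "2"
//	normalizeMsgType("update-v2") == "update-v2" // multi-char strings pass through
//
// Note that a printable one-character string like "2" also takes the byte path
// and becomes "50"; producers are assumed to send the type as a raw byte.
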
func logWorker(logger *zap.Logger, msg string, fields ...zap.Field) {
	if logger == nil {
		return
	}
	logger.Info(msg, fields...)
}
backend/scripts/010_add_stream_checkpoints.sql
@@ -0,0 +1,12 @@
-- Migration: Add stream checkpoints table for Redis Streams durability.
-- Tracks the last processed stream position per document.

CREATE TABLE IF NOT EXISTS stream_checkpoints (
    document_id    UUID PRIMARY KEY REFERENCES documents(id) ON DELETE CASCADE,
    last_stream_id TEXT NOT NULL,
    last_seq       BIGINT NOT NULL DEFAULT 0,
    updated_at     TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE INDEX IF NOT EXISTS idx_stream_checkpoints_updated_at
    ON stream_checkpoints(updated_at DESC);
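-- Sketch (not part of the migration): the upsert that UpsertStreamCheckpoint
-- presumably issues from the Go store layer; the exact statement may differ.
--
--   INSERT INTO stream_checkpoints (document_id, last_stream_id, last_seq)
--   VALUES ($1, $2, $3)
--   ON CONFLICT (document_id) DO UPDATE
--       SET last_stream_id = EXCLUDED.last_stream_id,
--           last_seq       = EXCLUDED.last_seq,
--           updated_at     = NOW();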
backend/scripts/011_add_update_history.sql
@@ -0,0 +1,19 @@
-- Migration: Add update history table for Redis Stream WAL.
-- Stores per-update payloads for recovery and replay.

CREATE TABLE IF NOT EXISTS document_update_history (
    id          BIGSERIAL PRIMARY KEY,
    document_id UUID NOT NULL REFERENCES documents(id) ON DELETE CASCADE,
    stream_id   TEXT NOT NULL,
    seq         BIGINT NOT NULL,
    payload     BYTEA NOT NULL,
    msg_type    TEXT,
    server_id   TEXT,
    created_at  TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_stream_id
    ON document_update_history(document_id, stream_id);

-- This unique index also serves (document_id, seq) range scans for replay.
CREATE UNIQUE INDEX IF NOT EXISTS uniq_update_history_document_seq
    ON document_update_history(document_id, seq);
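-- Sketch (not part of the migration): the queries the worker's store methods
-- presumably run against this table; the exact SQL lives in the Go store layer.
--
--   -- Replay updates newer than a checkpoint:
--   SELECT payload FROM document_update_history
--    WHERE document_id = $1 AND seq > $2
--    ORDER BY seq;
--
--   -- Retention (DeleteUpdateHistoryUpToSeq):
--   DELETE FROM document_update_history
--    WHERE document_id = $1 AND seq <= $2;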