feat: implement Redis Streams support with stream checkpoints and update history
- Added Redis Streams operations to the message bus interface and implementation. - Introduced StreamCheckpoint model to track last processed stream entry per document. - Implemented UpsertStreamCheckpoint and GetStreamCheckpoint methods in the Postgres store. - Created document_update_history table for storing update payloads for recovery and replay. - Developed update persist worker to handle Redis Stream updates and persist them to Postgres. - Enhanced Docker Compose configuration for Redis with persistence. - Updated frontend API to support fetching document state with optional share token. - Added connection stability monitoring in the Yjs document hook.
This commit is contained in:
@@ -1,22 +1,33 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type DocumentHandler struct {
|
||||
store *store.PostgresStore
|
||||
store *store.PostgresStore
|
||||
messageBus messagebus.MessageBus
|
||||
serverID string
|
||||
logger *zap.Logger
|
||||
}
|
||||
|
||||
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
|
||||
return &DocumentHandler{store: s}
|
||||
func NewDocumentHandler(s *store.PostgresStore, msgBus messagebus.MessageBus, serverID string, logger *zap.Logger) *DocumentHandler {
|
||||
return &DocumentHandler{
|
||||
store: s,
|
||||
messageBus: msgBus,
|
||||
serverID: serverID,
|
||||
logger: logger,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDocument creates a new document (requires auth)
|
||||
@@ -45,8 +56,6 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
|
||||
|
||||
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
|
||||
userID := auth.GetUserFromContext(c)
|
||||
fmt.Println("Getting userId, which is : ")
|
||||
fmt.Println(userID)
|
||||
if userID == nil {
|
||||
respondUnauthorized(c, "Authentication required to list documents")
|
||||
return
|
||||
@@ -113,6 +122,13 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
|
||||
}
|
||||
|
||||
userID := auth.GetUserFromContext(c)
|
||||
shareToken := c.Query("share")
|
||||
|
||||
doc, err := h.store.GetDocument(id)
|
||||
if err != nil {
|
||||
respondNotFound(c, "document")
|
||||
return
|
||||
}
|
||||
|
||||
// Check permission if authenticated
|
||||
if userID != nil {
|
||||
@@ -125,12 +141,22 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
|
||||
respondForbidden(c, "Access denied")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
doc, err := h.store.GetDocument(id)
|
||||
if err != nil {
|
||||
respondNotFound(c, "document")
|
||||
return
|
||||
} else {
|
||||
// Unauthenticated: require valid share token or public doc
|
||||
if shareToken != "" {
|
||||
valid, err := h.store.ValidateShareToken(c.Request.Context(), id, shareToken)
|
||||
if err != nil {
|
||||
respondInternalError(c, "Failed to validate share token", err)
|
||||
return
|
||||
}
|
||||
if !valid {
|
||||
respondForbidden(c, "Invalid or expired share token")
|
||||
return
|
||||
}
|
||||
} else if !doc.Is_Public {
|
||||
respondForbidden(c, "This document is not public. Please sign in to access.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Return empty byte slice if state is nil (new document)
|
||||
@@ -191,6 +217,16 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if streamID, seq, ok := h.addSnapshotMarker(c.Request.Context(), id); ok {
|
||||
if err := h.store.UpsertStreamCheckpoint(c.Request.Context(), id, streamID, seq); err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("Failed to upsert stream checkpoint after snapshot",
|
||||
zap.String("document_id", id.String()),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
|
||||
}
|
||||
|
||||
@@ -234,6 +270,43 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "Document deleted successfully"})
|
||||
}
|
||||
|
||||
func (h *DocumentHandler) addSnapshotMarker(ctx context.Context, documentID uuid.UUID) (string, int64, bool) {
|
||||
if h.messageBus == nil {
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
streamKey := "stream:" + documentID.String()
|
||||
seqKey := "seq:" + documentID.String()
|
||||
|
||||
seq, err := h.messageBus.Incr(ctx, seqKey)
|
||||
if err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("Failed to increment snapshot sequence",
|
||||
zap.String("document_id", documentID.String()),
|
||||
zap.Error(err))
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
|
||||
values := map[string]interface{}{
|
||||
"type": "snapshot",
|
||||
"server_id": h.serverID,
|
||||
"seq": seq,
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
streamID, err := h.messageBus.XAdd(ctx, streamKey, 10000, true, values)
|
||||
if err != nil {
|
||||
if h.logger != nil {
|
||||
h.logger.Warn("Failed to add snapshot marker to stream",
|
||||
zap.String("stream_key", streamKey),
|
||||
zap.Error(err))
|
||||
}
|
||||
return "", 0, false
|
||||
}
|
||||
return streamID, seq, true
|
||||
}
|
||||
|
||||
// GetDocumentPermission returns the user's permission level for a document
|
||||
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
|
||||
documentID, err := uuid.Parse(c.Param("id"))
|
||||
|
||||
@@ -7,10 +7,12 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// DocumentHandlerSuite tests document CRUD operations
|
||||
@@ -23,7 +25,7 @@ type DocumentHandlerSuite struct {
|
||||
// SetupTest runs before each test
|
||||
func (s *DocumentHandlerSuite) SetupTest() {
|
||||
s.BaseHandlerSuite.SetupTest()
|
||||
s.handler = NewDocumentHandler(s.store)
|
||||
s.handler = NewDocumentHandler(s.store, messagebus.NewLocalMessageBus(), "test-server", zap.NewNop())
|
||||
s.setupRouter()
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ShareHandlerSuite tests for share handler endpoints
|
||||
@@ -24,7 +25,7 @@ func (s *ShareHandlerSuite) SetupTest() {
|
||||
s.BaseHandlerSuite.SetupTest()
|
||||
|
||||
// Create handler and router
|
||||
authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret)
|
||||
authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret, zap.NewNop())
|
||||
s.handler = NewShareHandler(s.store, s.cfg)
|
||||
s.router = gin.New()
|
||||
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/config"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/hub"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -19,16 +24,18 @@ import (
|
||||
var connectionSem = make(chan struct{}, 200)
|
||||
|
||||
type WebSocketHandler struct {
|
||||
hub *hub.Hub
|
||||
store store.Store
|
||||
cfg *config.Config
|
||||
hub *hub.Hub
|
||||
store store.Store
|
||||
cfg *config.Config
|
||||
msgBus messagebus.MessageBus
|
||||
}
|
||||
|
||||
func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config) *WebSocketHandler {
|
||||
func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config, msgBus messagebus.MessageBus) *WebSocketHandler {
|
||||
return &WebSocketHandler{
|
||||
hub: h,
|
||||
store: s,
|
||||
cfg: cfg,
|
||||
hub: h,
|
||||
store: s,
|
||||
cfg: cfg,
|
||||
msgBus: msgBus,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,6 +177,105 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
||||
// Start goroutines
|
||||
go client.WritePump()
|
||||
go client.ReadPump()
|
||||
go wsh.replayBacklog(client, documentID)
|
||||
|
||||
log.Printf("Client connected: %s (user: %s) to room: %s", clientID, userName, roomID)
|
||||
}
|
||||
|
||||
const maxReplayUpdates = 5000
|
||||
|
||||
func (wsh *WebSocketHandler) replayBacklog(client *hub.Client, documentID uuid.UUID) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
checkpoint, err := wsh.store.GetStreamCheckpoint(ctx, documentID)
|
||||
if err != nil || checkpoint == nil || checkpoint.LastStreamID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
streamKey := "stream:" + documentID.String()
|
||||
var sent int
|
||||
|
||||
// Primary: Redis stream replay
|
||||
if wsh.msgBus != nil {
|
||||
messages, err := wsh.msgBus.XRange(ctx, streamKey, checkpoint.LastStreamID, "+")
|
||||
if err == nil && len(messages) > 0 {
|
||||
for _, msg := range messages {
|
||||
if msg.ID == checkpoint.LastStreamID {
|
||||
continue
|
||||
}
|
||||
if sent >= maxReplayUpdates {
|
||||
log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String())
|
||||
return
|
||||
}
|
||||
msgType := getString(msg.Values["type"])
|
||||
if msgType != "update" {
|
||||
continue
|
||||
}
|
||||
seq := parseInt64(msg.Values["seq"])
|
||||
if seq <= checkpoint.LastSeq {
|
||||
continue
|
||||
}
|
||||
payloadB64 := getString(msg.Values["yjs_payload"])
|
||||
payload, err := base64.StdEncoding.DecodeString(payloadB64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if client.Enqueue(payload) {
|
||||
sent++
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: DB history replay
|
||||
updates, err := wsh.store.ListUpdateHistoryAfterSeq(ctx, documentID, checkpoint.LastSeq, maxReplayUpdates)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, upd := range updates {
|
||||
if sent >= maxReplayUpdates {
|
||||
log.Printf("Replay capped at %d updates for doc %s", maxReplayUpdates, documentID.String())
|
||||
return
|
||||
}
|
||||
if client.Enqueue(upd.Payload) {
|
||||
sent++
|
||||
} else {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getString(value interface{}) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return v
|
||||
case []byte:
|
||||
return string(v)
|
||||
default:
|
||||
return fmt.Sprint(v)
|
||||
}
|
||||
}
|
||||
|
||||
func parseInt64(value interface{}) int64 {
|
||||
switch v := value.(type) {
|
||||
case int64:
|
||||
return v
|
||||
case int:
|
||||
return int64(v)
|
||||
case uint64:
|
||||
return int64(v)
|
||||
case string:
|
||||
if parsed, err := strconv.ParseInt(v, 10, 64); err == nil {
|
||||
return parsed
|
||||
}
|
||||
case []byte:
|
||||
if parsed, err := strconv.ParseInt(string(v), 10, 64); err == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user