diff --git a/backend/go.mod b/backend/go.mod index e0f3e91..48a0ce4 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -11,7 +11,9 @@ require ( github.com/gorilla/websocket v1.5.3 github.com/joho/godotenv v1.5.1 github.com/lib/pq v1.10.9 + github.com/redis/go-redis/v9 v9.17.3 github.com/stretchr/testify v1.11.1 + go.uber.org/zap v1.27.1 golang.org/x/oauth2 v0.34.0 ) @@ -19,8 +21,10 @@ require ( cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/bytedance/sonic v1.14.0 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/gabriel-vasile/mimetype v1.4.9 // indirect github.com/gin-contrib/sse v1.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect @@ -29,7 +33,6 @@ require ( github.com/goccy/go-yaml v1.18.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.3.0 // indirect - github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect @@ -41,6 +44,7 @@ require ( github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.3.0 // indirect go.uber.org/mock v0.5.0 // indirect + go.uber.org/multierr v1.10.0 // indirect golang.org/x/arch v0.20.0 // indirect golang.org/x/crypto v0.40.0 // indirect golang.org/x/mod v0.25.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 41d12c1..53363ca 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -1,15 +1,22 @@ cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M= github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= github.com/gin-contrib/cors v1.7.6 h1:3gQ8GMzs1Ylpf70y8bMw4fVpycXIeX1ZemuSQIsnQQY= @@ -68,6 +75,8 @@ github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg= github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= +github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4= +github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -83,8 +92,14 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA= github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU= go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= +go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c= golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk= golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= diff --git a/backend/internal/hub/hub.go b/backend/internal/hub/hub.go index feba3d5..fb15b97 100644 --- a/backend/internal/hub/hub.go +++ b/backend/internal/hub/hub.go @@ -1,18 +1,20 @@ package hub import ( - "log" + "context" "sync" "time" + "github.com/M1ngdaXie/realtime-collab/internal/messagebus" "github.com/google/uuid" "github.com/gorilla/websocket" + "go.uber.org/zap" ) type Message struct { - RoomID string - Data []byte - sender *Client + RoomID string + Data []byte + sender *Client } type Client struct { @@ -38,22 +40,34 @@ type Room struct { ID string clients map[*Client]bool mu sync.RWMutex - lastAwareness []byte // 存储最新的 awareness 消息,用于新用户加入时立即同步 + cancel context.CancelFunc } type Hub struct { - rooms map[string]*Room - mu sync.RWMutex - Register chan *Client // Exported - Unregister chan *Client // Exported - Broadcast chan *Message // Exported + rooms map[string]*Room + mu sync.RWMutex + Register chan *Client // Exported + Unregister chan *Client // Exported + Broadcast chan *Message // Exported + + //redis pub/sub config + messagebus messagebus.MessageBus + logger *zap.Logger + serverID string + fallbackMode bool } -func NewHub() *Hub { + +func NewHub(messagebus messagebus.MessageBus, serverID string, logger *zap.Logger) *Hub { return &Hub{ rooms: make(map[string]*Room), Register: make(chan *Client, 256), Unregister: make(chan *Client, 256), Broadcast: make(chan *Message, 4096), + // redis + messagebus: messagebus, + serverID: serverID, + logger: logger, + fallbackMode: false, // 默认 Redis 正常工作 } } @@ -75,300 +89,543 @@ func (h *Hub) registerClient(client *Client) { defer h.mu.Unlock() room, exists := h.rooms[client.roomID] + + // --- 1. 初始化房间 (仅针对该服务器上的第一个人) --- if !exists { room = &Room{ ID: client.roomID, clients: make(map[*Client]bool), + cancel: nil, + // lastAwareness 已被我们从结构体中物理删除或不再使用 } h.rooms[client.roomID] = room - log.Printf("Created new room with ID: %s", client.roomID) - } - room.mu.Lock() - room.clients[client] = true - // 获取现有的 awareness 数据(如果有的话) - awarenessData := room.lastAwareness - room.mu.Unlock() - log.Printf("Client %s joined room %s (total clients: %d)", client.ID, client.roomID, len(room.clients)) + h.logger.Info("Created new local room instance", zap.String("room_id", client.roomID)) - // 如果房间有之前的 awareness 状态,立即发送给新用户 - // 这样新用户不需要等待其他用户的下一次广播就能看到在线用户 - if len(awarenessData) > 0 { - select { - case client.send <- awarenessData: - log.Printf("📤 Sent existing awareness to new client %s", client.ID) - default: - log.Printf("⚠️ Failed to send awareness to new client %s (channel full)", client.ID) + // 开启跨服订阅 + if !h.fallbackMode && h.messagebus != nil { + ctx, cancel := context.WithCancel(context.Background()) + room.cancel = cancel + + msgChan, err := h.messagebus.Subscribe(ctx, client.roomID) + if err != nil { + h.logger.Error("Redis Subscribe failed", zap.Error(err)) + cancel() + room.cancel = nil + } else { + // 启动转发协程:确保以后别的服务器的消息能传给这台机器的人 + go h.startRoomMessageForwarding(ctx, client.roomID, msgChan) + } } } + + // --- 2. 将客户端加入本地房间列表 --- + room.mu.Lock() + room.clients[client] = true + room.mu.Unlock() + + h.logger.Info("Client linked to room", zap.String("client_id", client.ID)) + + // --- 3. 核心改进:立刻同步全量状态 --- + // 无论是不是第一个人,只要有人进来,我们就去 Redis 抓取所有人的状态发给他 + // hub/hub.go 内部的 registerClient 函数 + +// ... 之前的代码保持不变 ... + +if !h.fallbackMode && h.messagebus != nil { + go func(c *Client) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // 1. 从 Redis 抓取 + awarenessMap, err := h.messagebus.GetAllAwareness(ctx, c.roomID) + if err != nil { + h.logger.Error("Redis sync failed in goroutine", + zap.String("client_id", c.ID), + zap.Error(err)) + return + } + + if len(awarenessMap) == 0 { + h.logger.Debug("No awareness data found in Redis for sync", zap.String("room_id", c.roomID)) + return + } + + h.logger.Info("Starting state delivery to joiner", + zap.String("client_id", c.ID), + zap.Int("items", len(awarenessMap))) + + // 2. 逐条发送,带锁保护 + sentCount := 0 + for clientID, data := range awarenessMap { + c.sendMu.Lock() + // 🛑 核心防御:检查通道是否已被 unregisterClient 关闭 + if c.sendClosed { + c.sendMu.Unlock() + h.logger.Warn("Sync aborted: client channel closed while sending", + zap.String("client_id", c.ID), + zap.Uint64("target_yjs_id", clientID)) + return // 直接退出协程,不发了 + } + + select { + case c.send <- data: + sentCount++ + default: + // 缓冲区满了(通常是因为网络太卡),记录一条警告 + h.logger.Warn("Sync item skipped: client send buffer full", + zap.String("client_id", c.ID), + zap.Uint64("target_yjs_id", clientID)) + } + c.sendMu.Unlock() + } + + h.logger.Info("State sync completed successfully", + zap.String("client_id", c.ID), + zap.Int("delivered", sentCount)) + }(client) +} } func (h *Hub) unregisterClient(client *Client) { - h.mu.Lock() - // --------------------------------------------------- - // 注意:这里不要用 defer h.mu.Unlock(),因为我们要手动控制锁 - // --------------------------------------------------- + h.mu.Lock() + // 注意:这里不使用 defer,因为我们需要根据逻辑流手动释放锁,避免阻塞后续的 Redis 操作 - room, exists := h.rooms[client.roomID] - if !exists { - h.mu.Unlock() - return + room, exists := h.rooms[client.roomID] + if !exists { + h.mu.Unlock() + return + } + + // 1. 局部清理:从房间内移除客户端 + room.mu.Lock() + if _, ok := room.clients[client]; ok { + delete(room.clients, client) + + // 安全关闭客户端的发送管道 + client.sendMu.Lock() + if !client.sendClosed { + close(client.send) + client.sendClosed = true + } + client.sendMu.Unlock() + } + remainingClientsCount := len(room.clients) + room.mu.Unlock() + + h.logger.Info("Unregistered client from room", + zap.String("client_id", client.ID), + zap.String("room_id", client.roomID), + zap.Int("remaining_clients", remainingClientsCount), + ) + + // 2. 分布式清理:删除 Redis 中的感知数据 (Awareness) + if !h.fallbackMode && h.messagebus != nil { + go func() { + // 使用带超时的 Context 执行删除 + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + client.idsMu.Lock() + // 遍历该客户端在本机观察到的所有 Yjs ID + for clientID := range client.observedYjsIDs { + err := h.messagebus.DeleteAwareness(ctx, client.roomID, clientID) + h.logger.Info("DEBUG: IDs to cleanup", + zap.String("client_id", client.ID), + zap.Any("ids", client.observedYjsIDs)) + if err != nil { + h.logger.Warn("Failed to delete awareness from Redis", + zap.Uint64("yjs_id", clientID), + zap.Error(err), + ) + } + } + client.idsMu.Unlock() + }() + } + + // 3. 协作清理:发送“僵尸删除”消息给其他幸存者 + if remainingClientsCount > 0 { + client.idsMu.Lock() + clientClocks := make(map[uint64]uint64, len(client.observedYjsIDs)) + for id, clock := range client.observedYjsIDs { + clientClocks[id] = clock + } + client.idsMu.Unlock() + + if len(clientClocks) > 0 { + // 构造 Yjs 协议格式的删除消息 + deleteMsg := MakeYjsDeleteMessage(clientClocks) + msg := &Message{ + RoomID: client.roomID, + Data: deleteMsg, + sender: nil, // 系统发送 + } + + // 异步发送到广播通道,避免在持有 h.mu 时发生死锁 + go func() { + select { + case h.Broadcast <- msg: + case <-time.After(500 * time.Millisecond): + h.logger.Error("Critical: Failed to broadcast cleanup message (channel blocked)") + } + }() + } + } + + // 4. 房间清理:如果是最后一个人,彻底销毁房间资源 +if remainingClientsCount == 0 { + h.logger.Info("Room is empty, performing deep cleanup", zap.String("room_id", client.roomID)) + + // A. 停止转发协程 + if room.cancel != nil { + room.cancel() } - room.mu.Lock() // 锁住房间 - if _, ok := room.clients[client]; ok { - delete(room.clients, client) - - // 关闭连接通道 - client.sendMu.Lock() - if !client.sendClosed { - close(client.send) - client.sendClosed = true - } - client.sendMu.Unlock() - - log.Printf("Client disconnected: %s", client.ID) - } - - // 检查房间是否还有其他人 - remainingClientsCount := len(room.clients) - room.mu.Unlock() // 解锁房间 (我们已经删完人了) - - // --------------------------------------------------- - // [新增] 僵尸用户清理逻辑 (核心修改) - // --------------------------------------------------- - - // 只有当房间里还有其他人时,才需要广播通知 - if remainingClientsCount > 0 { - // 1. 从 client 的小本本里取出它用过的 Yjs ID 和对应的 clock - client.idsMu.Lock() - clientClocks := make(map[uint64]uint64, len(client.observedYjsIDs)) - for id, clock := range client.observedYjsIDs { - clientClocks[id] = clock - } - client.idsMu.Unlock() - - // 2. 如果有记录到的 ID,就伪造删除消息 (使用 clock+1) - if len(clientClocks) > 0 { - deleteMsg := MakeYjsDeleteMessage(clientClocks) // 调用工具函数,传入 clientID -> clock map - - log.Printf("🧹 Notifying others to remove Yjs IDs with clocks: %v", clientClocks) + // B. 分布式彻底清理 (关键改动!) + if !h.fallbackMode && h.messagebus != nil { + go func(rID string) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() - // 3. 广播给房间里的幸存者 - // 构造一个消息对象 - msg := &Message{ - RoomID: client.roomID, - Data: deleteMsg, - sender: nil, // sender 设为 nil,表示系统消息 + // ✨ 1. 直接删除整个 Redis Hash 表,不留任何死角 + if err := h.messagebus.ClearAllAwareness(ctx, rID); err != nil { + h.logger.Warn("Failed to clear total awareness from Redis", zap.Error(err)) } - - // !!特别注意!! - // 不要在这里直接调用 h.broadcastMessage(msg),因为那会尝试重新获取 h.mu 锁导致死锁 - // 我们直接把它扔到 Channel 里,让 Run() 去处理 - // 必须在一个非阻塞的 goroutine 里发,或者确保 channel 有缓冲 - - go func() { - // 使用 select 尝试发送,但如果满了,我们要稍微等一下,而不是直接丢弃 - // 因为这是“清理僵尸”的关键消息,丢了就会出 Bug - select { - case h.Broadcast <- msg: - // 发送成功 - case <-time.After(500 * time.Millisecond): - // 如果 500ms 还没塞进去,那说明系统真的挂了,只能丢弃并打印错误 - log.Printf("❌ Critical: Failed to broadcast cleanup message (Channel blocked)") - } - }() - } + + // 2. 取消订阅 + if err := h.messagebus.Unsubscribe(ctx, rID); err != nil { + h.logger.Warn("Failed to unsubscribe", zap.Error(err)) + } + }(client.roomID) } + // C. 从内存中移除 + delete(h.rooms, client.roomID) +} - // --------------------------------------------------- - // 结束清理逻辑 - // --------------------------------------------------- - - if remainingClientsCount == 0 { - delete(h.rooms, client.roomID) - log.Printf("Room destroyed: %s", client.roomID) - } - - h.mu.Unlock() // 最后解锁 Hub + h.mu.Unlock() // 手动释放 Hub 锁 } const ( - writeWait = 10 * time.Second - pongWait = 60 * time.Second - pingPeriod = (pongWait * 9) / 10 // 54 seconds + writeWait = 10 * time.Second + pongWait = 60 * time.Second + pingPeriod = (pongWait * 9) / 10 // 54 seconds maxSendFailures = 5 ) func (h *Hub) broadcastMessage(message *Message) { - h.mu.RLock() - room, exists := h.rooms[message.RoomID] - h.mu.RUnlock() - if !exists { - log.Printf("Room %s does not exist for broadcasting", message.RoomID) - return - } + h.mu.RLock() + room, exists := h.rooms[message.RoomID] + h.mu.RUnlock() - // 如果是 awareness 消息 (type=1),保存它供新用户使用 - if len(message.Data) > 0 && message.Data[0] == 1 { - room.mu.Lock() - room.lastAwareness = make([]byte, len(message.Data)) - copy(room.lastAwareness, message.Data) - room.mu.Unlock() - } + if !exists { + // 如果房间不存在,没必要继续 + return + } - room.mu.RLock() - defer room.mu.RUnlock() + // 1. 处理 Awareness 缓存 (Type 1) + // if len(message.Data) > 0 && message.Data[0] == 1 { + // room.mu.Lock() + // room.lastAwareness = make([]byte, len(message.Data)) + // copy(room.lastAwareness, message.Data) + // room.mu.Unlock() + // } - for client := range room.clients { - if client != message.sender { - select { - case client.send <- message.Data: - // Success - reset failure count - client.failureMu.Lock() - client.failureCount = 0 - client.failureMu.Unlock() + // 2. 本地广播 + h.broadcastToLocalClients(room, message.Data, message.sender) - default: - // Failed - increment failure count - client.failureMu.Lock() - client.failureCount++ - currentFailures := client.failureCount - client.failureMu.Unlock() + // 只有本地客户端发出的消息 (sender != nil) 才推送到 Redis + if message.sender != nil && !h.fallbackMode && h.messagebus != nil { + go func() { // 建议异步 Publish,不阻塞 Hub 的主循环 + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() - log.Printf("Failed to send to client %s (channel full, failures: %d/%d)", - client.ID, currentFailures, maxSendFailures) - - // Disconnect if threshold exceeded - if currentFailures >= maxSendFailures { - log.Printf("Client %s exceeded max send failures, disconnecting", client.ID) - go func(c *Client) { - c.unregister() - c.Conn.Close() - }(client) - } - } - } - } + err := h.messagebus.Publish(ctx, message.RoomID, message.Data) + if err != nil { + h.logger.Error("MessageBus publish failed", + zap.String("room_id", message.RoomID), + zap.Error(err), + ) + } + }() + } } +func (h *Hub) broadcastToLocalClients(room *Room, data []byte, sender *Client) { + room.mu.RLock() + defer room.mu.RUnlock() + // 2. 遍历房间内所有本地连接的客户端 + for client := range room.clients { + // 3. 排除发送者(如果是从 Redis 来的消息,sender 通常为 nil) + if client != sender { + select { + case client.send <- data: + // 发送成功:重置该客户端的失败计数 + client.failureMu.Lock() + client.failureCount = 0 + client.failureMu.Unlock() + + default: + + client.handleSendFailure() + } + } + } +} +func (h *Hub) startRoomMessageForwarding(ctx context.Context, roomID string, msgChan <-chan []byte) { + h.logger.Info("Starting message forwarding from Redis to room", + zap.String("room_id", roomID), + zap.String("server_id", h.serverID), + ) + + for { + select { + case <-ctx.Done(): + h.logger.Info("Stopping Redis message forwarding", zap.String("room_id", roomID)) + return + + case data, ok := <-msgChan: + if !ok { + h.logger.Warn("Redis message channel closed for room", zap.String("room_id", roomID)) + return + } + + // 1. 日志记录:记录消息类型和长度 + msgType := byte(0) + if len(data) > 0 { + msgType = data[0] + } + h.logger.Debug("Received message from Redis", + zap.String("room_id", roomID), + zap.Int("bytes", len(data)), + zap.Uint8("msg_type", msgType), + ) + + // 2. 获取本地房间对象 + h.mu.RLock() + room, exists := h.rooms[roomID] + h.mu.RUnlock() + + if !exists { + // 这种情况常见于:Redis 消息飞过来时,本地最后一个用户刚好断开删除了房间 + h.logger.Warn("Received Redis message but room does not exist locally", + zap.String("room_id", roomID), + ) + continue + } + + // 3. 广播给本地所有客户端 + // sender=nil 确保每个人(包括原本发送这条消息的人,如果他在本台机器上)都能收到 + h.broadcastToLocalClients(room, data, nil) + + // 4. 如果是感知数据 (Awareness),更新房间缓存 + // 这样后来加入的用户能立即看到其他人的光标 + // if len(data) > 0 && data[0] == 1 { + // room.mu.Lock() + // room.lastAwareness = make([]byte, len(data)) + // copy(room.lastAwareness, data) + // room.mu.Unlock() + // } + } + } +} + +// 辅助方法:处理发送失败逻辑,保持代码整洁 (DRY 原则) +func (c *Client) handleSendFailure() { + c.failureMu.Lock() + c.failureCount++ + currentFailures := c.failureCount + c.failureMu.Unlock() + + // 直接通过 c.hub 访问日志,代码非常自然 + c.hub.logger.Warn("Failed to send message to client", + zap.String("clientID", c.ID), + zap.String("roomID", c.roomID), + zap.Int("failures", currentFailures), + zap.Int("max", maxSendFailures), + ) + + if currentFailures >= maxSendFailures { + c.hub.logger.Error("Client exceeded max failures, disconnecting", + zap.String("clientID", c.ID)) + + // 这里的异步处理很正确,防止阻塞当前的广播循环 + go func() { + c.unregister() + c.Conn.Close() + }() + } +} +func (h *Hub) SetFallbackMode(fallback bool) { + h.mu.Lock() + // 1. 检查状态是否有变化(幂等性) + if h.fallbackMode == fallback { + h.mu.Unlock() + return + } + + // 2. 更新状态并记录日志 + h.fallbackMode = fallback + if fallback { + h.logger.Warn("Hub entering FALLBACK MODE (local-only communication)") + } else { + h.logger.Info("Hub restored to DISTRIBUTED MODE (cross-server sync enabled)") + } + + // 3. 如果是从备用模式恢复 (fallback=false),需要重新激活所有房间的 Redis 订阅 + if !fallback && h.messagebus != nil { + h.logger.Info("Recovering Redis subscriptions for existing rooms", zap.Int("room_count", len(h.rooms))) + + for roomID, room := range h.rooms { + // 先清理旧的订阅句柄(如果有残留) + if room.cancel != nil { + room.cancel() + room.cancel = nil + } + + // 为每个房间重新开启订阅 + ctx, cancel := context.WithCancel(context.Background()) + room.cancel = cancel + + // 尝试重新订阅 + msgChan, err := h.messagebus.Subscribe(ctx, roomID) + if err != nil { + h.logger.Error("Failed to re-subscribe during recovery", + zap.String("room_id", roomID), + zap.Error(err), + ) + cancel() + room.cancel = nil + continue + } + + // 重新启动转发协程 + go h.startRoomMessageForwarding(ctx, roomID, msgChan) + + h.logger.Debug("Successfully re-synced room with Redis", zap.String("room_id", roomID)) + } + } + h.mu.Unlock() +} func (c *Client) ReadPump() { - c.Conn.SetReadDeadline(time.Now().Add(pongWait)) - c.Conn.SetPongHandler(func(string) error { - c.Conn.SetReadDeadline(time.Now().Add(pongWait)) - return nil - }) - defer func() { - c.unregister() - c.Conn.Close() - }() - - for { - messageType, message, err := c.Conn.ReadMessage() - if err != nil { - if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Printf("error: %v", err) - } - break - } - if len(message) > 0 && message[0] == 1 { - log.Printf("DEBUG: 收到 Awareness (光标) 消息 from User %s Permission: %s", c.ID, c.Permission) -} - // ========================================================== - // 1. 偷听逻辑 (Sniff) - 必须放在转发之前! - // ========================================================== - if messageType == websocket.BinaryMessage && len(message) > 0 && message[0] == 1 { - clockMap := SniffYjsClientIDs(message) - if len(clockMap) > 0 { - c.idsMu.Lock() - for id, clock := range clockMap { - if clock > c.observedYjsIDs[id] { - log.Printf("🕵️ [Sniff] Client %s uses YjsID: %d (clock: %d)", c.ID, id, clock) - c.observedYjsIDs[id] = clock - } - } - c.idsMu.Unlock() - } - } + c.Conn.SetReadDeadline(time.Now().Add(pongWait)) + c.Conn.SetPongHandler(func(string) error { + c.Conn.SetReadDeadline(time.Now().Add(pongWait)) + return nil + }) + defer func() { + c.unregister() + c.Conn.Close() + }() - // ========================================================== - // 2. 权限检查 - 只有编辑权限的用户才能广播消息 - // ========================================================== - // ========================================================== - // 2. 权限检查 - 精细化拦截 (Fine-grained Permission Check) - // ========================================================== - if c.Permission == "view" { - // Yjs Protocol: - // message[0]: MessageType (0=Sync, 1=Awareness) - // message[1]: SyncMessageType (0=Step1, 1=Step2, 2=Update) - - // 只有当消息是 "Sync Update" (修改文档) 时,才拦截 - isSyncMessage := len(message) > 0 && message[0] == 0 - isUpdateOp := len(message) > 1 && message[1] == 2 + for { + messageType, message, err := c.Conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + c.hub.logger.Warn("Unexpected WebSocket close", + zap.String("client_id", c.ID), + zap.Error(err)) + } + break + } - if isSyncMessage && isUpdateOp { - log.Printf("🛡️ [Security] Blocked unauthorized WRITE from view-only user: %s", c.ID) - continue // ❌ 拦截修改 - } - - // ✅ 放行: - // 1. Awareness (光标): message[0] == 1 - // 2. SyncStep1/2 (握手加载文档): message[1] == 0 or 1 - } + // 1. Sniff Yjs client IDs from awareness messages (before broadcast) + if messageType == websocket.BinaryMessage && len(message) > 0 && message[0] == 1 { + clockMap := SniffYjsClientIDs(message) + if len(clockMap) > 0 { + c.idsMu.Lock() + for id, clock := range clockMap { + if clock > c.observedYjsIDs[id] { + c.hub.logger.Debug("Sniffed Yjs client ID", + zap.String("client_id", c.ID), + zap.Uint64("yjs_id", id), + zap.Uint64("clock", clock)) + c.observedYjsIDs[id] = clock + } + } + c.idsMu.Unlock() - // ========================================================== - // 3. 转发逻辑 (Broadcast) - 恢复协作功能 - // ========================================================== - if messageType == websocket.BinaryMessage { - // 注意:这里要检查 channel 是否已满,避免阻塞导致 ReadPump 卡死 - select { - case c.hub.Broadcast <- &Message{ - RoomID: c.roomID, - Data: message, - sender: c, - }: - // 发送成功 - default: - log.Printf("⚠️ Hub broadcast channel is full, dropping message from %s", c.ID) - } - } - } + // Cache awareness in Redis for cross-server sync + if !c.hub.fallbackMode && c.hub.messagebus != nil { + go func(cm map[uint64]uint64, msg []byte) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + for clientID := range cm { + if err := c.hub.messagebus.SetAwareness(ctx, c.roomID, clientID, msg); err != nil { + c.hub.logger.Warn("Failed to cache awareness in Redis", + zap.Uint64("yjs_id", clientID), + zap.Error(err)) + } + } + }(clockMap, message) + } + } + } + + // 2. Permission check - block write operations from view-only users + if c.Permission == "view" { + isSyncMessage := len(message) > 0 && message[0] == 0 + isUpdateOp := len(message) > 1 && message[1] == 2 + + if isSyncMessage && isUpdateOp { + c.hub.logger.Warn("Blocked write from view-only user", + zap.String("client_id", c.ID)) + continue + } + // Allow: Awareness (type=1), SyncStep1/2 (type=0, subtype=0/1) + } + + // 3. Broadcast to room + if messageType == websocket.BinaryMessage { + select { + case c.hub.Broadcast <- &Message{ + RoomID: c.roomID, + Data: message, + sender: c, + }: + default: + c.hub.logger.Warn("Hub broadcast channel full, dropping message", + zap.String("client_id", c.ID)) + } + } + } } func (c *Client) WritePump() { - ticker := time.NewTicker(pingPeriod) - defer func() { - ticker.Stop() - c.unregister() // NEW: Now WritePump also unregisters - c.Conn.Close() - }() + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.unregister() + c.Conn.Close() + }() - for { - select { - case message, ok := <-c.send: - c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) - if !ok { - // Hub closed the channel - c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) - return - } + for { + select { + case message, ok := <-c.send: + c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + c.Conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } - err := c.Conn.WriteMessage(websocket.BinaryMessage, message) - if err != nil { - log.Printf("Error writing message to client %s: %v", c.ID, err) - return - } + if err := c.Conn.WriteMessage(websocket.BinaryMessage, message); err != nil { + c.hub.logger.Warn("Error writing message to client", + zap.String("client_id", c.ID), + zap.Error(err)) + return + } - case <-ticker.C: - c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) - if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { - log.Printf("Ping failed for client %s: %v", c.ID, err) - return - } - } - } + case <-ticker.C: + c.Conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil { + c.hub.logger.Debug("Ping failed for client", + zap.String("client_id", c.ID), + zap.Error(err)) + return + } + } + } } - func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string, permission string, conn *websocket.Conn, hub *Hub, roomID string) *Client { return &Client{ ID: id, @@ -384,7 +641,7 @@ func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string } } func (c *Client) unregister() { - c.unregisterOnce.Do(func() { - c.hub.Unregister <- c - }) -} \ No newline at end of file + c.unregisterOnce.Do(func() { + c.hub.Unregister <- c + }) +} diff --git a/backend/internal/hub/yjs_protocol.go b/backend/internal/hub/yjs_protocol.go index 71479d0..10a16be 100644 --- a/backend/internal/hub/yjs_protocol.go +++ b/backend/internal/hub/yjs_protocol.go @@ -8,92 +8,106 @@ import ( // data: 前端发来的 []byte // return: map[clientID]clock func SniffYjsClientIDs(data []byte) map[uint64]uint64 { - // 简单的防御:如果不是 Type 1 (Awareness) 消息,直接忽略 - if len(data) < 2 || data[0] != 1 { - return nil - } + // 简单的防御:如果不是 Type 1 (Awareness) 消息,直接忽略 + if len(data) < 2 || data[0] != 1 { + return nil + } - result := make(map[uint64]uint64) - offset := 1 // 跳过 [0] MessageType + result := make(map[uint64]uint64) + offset := 1 // 跳过 [0] MessageType - // 读取总长度 (跳过) - _, n := binary.Uvarint(data[offset:]) - if n <= 0 { return nil } - offset += n + // 读取总长度 (跳过) + _, n := binary.Uvarint(data[offset:]) + if n <= 0 { + return nil + } + offset += n - // 读取 Count (包含几个客户端的信息) - count, n := binary.Uvarint(data[offset:]) - if n <= 0 { return nil } - offset += n + // 读取 Count (包含几个客户端的信息) + count, n := binary.Uvarint(data[offset:]) + if n <= 0 { + return nil + } + offset += n - // 循环读取每个客户端的信息 - for i := 0; i < int(count); i++ { - if offset >= len(data) { break } + // 循环读取每个客户端的信息 + for i := 0; i < int(count); i++ { + if offset >= len(data) { + break + } - // 1. 读取 ClientID - id, n := binary.Uvarint(data[offset:]) - if n <= 0 { break } - offset += n + // 1. 读取 ClientID + id, n := binary.Uvarint(data[offset:]) + if n <= 0 { + break + } + offset += n - // 2. 读取 Clock (现在我们需要保存它!) - clock, n := binary.Uvarint(data[offset:]) - if n <= 0 { break } - offset += n + // 2. 读取 Clock (现在我们需要保存它!) + clock, n := binary.Uvarint(data[offset:]) + if n <= 0 { + break + } + offset += n - // 保存 clientID -> clock - result[id] = clock + // 保存 clientID -> clock + result[id] = clock - // 3. 跳过 JSON String - if offset >= len(data) { break } - strLen, n := binary.Uvarint(data[offset:]) - if n <= 0 { break } - offset += n + // 3. 跳过 JSON String + if offset >= len(data) { + break + } + strLen, n := binary.Uvarint(data[offset:]) + if n <= 0 { + break + } + offset += n - // 跳过具体字符串内容 - offset += int(strLen) - } + // 跳过具体字符串内容 + offset += int(strLen) + } - return result + return result } // 这个函数用来伪造一条"删除消息" // 输入:clientClocks map[clientID]clock - 要删除的 ClientID 及其最后已知的 clock 值 // 输出:可以广播给前端的 []byte func MakeYjsDeleteMessage(clientClocks map[uint64]uint64) []byte { - if len(clientClocks) == 0 { - return nil - } + if len(clientClocks) == 0 { + return nil + } - // 构造 Payload (负载) - // 格式: [Count] [ID] [Clock] [StringLen] [String] ... - payload := make([]byte, 0) + // 构造 Payload (负载) + // 格式: [Count] [ID] [Clock] [StringLen] [String] ... + payload := make([]byte, 0) - // 写入 Count (变长整数) - buf := make([]byte, 10) - n := binary.PutUvarint(buf, uint64(len(clientClocks))) - payload = append(payload, buf[:n]...) + // 写入 Count (变长整数) + buf := make([]byte, 10) + n := binary.PutUvarint(buf, uint64(len(clientClocks))) + payload = append(payload, buf[:n]...) - for id, clock := range clientClocks { - // ClientID - n = binary.PutUvarint(buf, id) - payload = append(payload, buf[:n]...) - // Clock: 必须使用 clock + 1,这样 Yjs 才会接受这个更新! - n = binary.PutUvarint(buf, clock+1) - payload = append(payload, buf[:n]...) - // String Length (填 4,因为 "null" 长度是 4) - n = binary.PutUvarint(buf, 4) - payload = append(payload, buf[:n]...) - // String Content (这里是关键:null 代表删除用户) - payload = append(payload, []byte("null")...) - } + for id, clock := range clientClocks { + // ClientID + n = binary.PutUvarint(buf, id) + payload = append(payload, buf[:n]...) + // Clock: 必须使用 clock + 1,这样 Yjs 才会接受这个更新! + n = binary.PutUvarint(buf, clock+1) + payload = append(payload, buf[:n]...) + // String Length (填 4,因为 "null" 长度是 4) + n = binary.PutUvarint(buf, 4) + payload = append(payload, buf[:n]...) + // String Content (这里是关键:null 代表删除用户) + payload = append(payload, []byte("null")...) + } - // 构造最终消息: [Type=1] [PayloadLength] [Payload] - finalMsg := make([]byte, 0) - finalMsg = append(finalMsg, 1) // Type 1 + // 构造最终消息: [Type=1] [PayloadLength] [Payload] + finalMsg := make([]byte, 0) + finalMsg = append(finalMsg, 1) // Type 1 - n = binary.PutUvarint(buf, uint64(len(payload))) - finalMsg = append(finalMsg, buf[:n]...) // Length - finalMsg = append(finalMsg, payload...) // Body + n = binary.PutUvarint(buf, uint64(len(payload))) + finalMsg = append(finalMsg, buf[:n]...) // Length + finalMsg = append(finalMsg, payload...) // Body - return finalMsg -} \ No newline at end of file + return finalMsg +} diff --git a/backend/internal/logger/logger.go b/backend/internal/logger/logger.go new file mode 100644 index 0000000..b7a0bf6 --- /dev/null +++ b/backend/internal/logger/logger.go @@ -0,0 +1,41 @@ +package logger + +import ( + "os" + + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// NewLogger creates a production-grade logger with appropriate configuration +func NewLogger(isDevelopment bool) (*zap.Logger, error) { + var config zap.Config + + if isDevelopment { + config = zap.NewDevelopmentConfig() + config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder + } else { + config = zap.NewProductionConfig() + config.EncoderConfig.TimeKey = "timestamp" + config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder + } + + // Allow DEBUG level in development + if isDevelopment { + config.Level = zap.NewAtomicLevelAt(zapcore.DebugLevel) + } + + logger, err := config.Build() + if err != nil { + return nil, err + } + + return logger, nil +} + +// NewLoggerFromEnv creates logger based on environment +func NewLoggerFromEnv() (*zap.Logger, error) { + env := os.Getenv("ENVIRONMENT") + isDev := env == "" || env == "development" || env == "dev" + return NewLogger(isDev) +} diff --git a/backend/internal/messagebus/interface.go b/backend/internal/messagebus/interface.go new file mode 100644 index 0000000..8b4d025 --- /dev/null +++ b/backend/internal/messagebus/interface.go @@ -0,0 +1,80 @@ +package messagebus + +import ( + "context" +) + +// MessageBus abstracts message distribution across server instances +type MessageBus interface { + // Publish sends a message to a specific room channel + // data must be preserved as-is (binary safe) + Publish(ctx context.Context, roomID string, data []byte) error + + // Subscribe listens to messages for a specific room + // Returns a channel that receives binary messages + Subscribe(ctx context.Context, roomID string) (<-chan []byte, error) + + // Unsubscribe stops listening to a room + Unsubscribe(ctx context.Context, roomID string) error + + // SetAwareness caches awareness data for a client in a room + SetAwareness(ctx context.Context, roomID string, clientID uint64, data []byte) error + + // GetAllAwareness retrieves all cached awareness for a room + GetAllAwareness(ctx context.Context, roomID string) (map[uint64][]byte, error) + + // DeleteAwareness removes awareness cache for a client + DeleteAwareness(ctx context.Context, roomID string, clientID uint64) error + + ClearAllAwareness(ctx context.Context, roomID string) error + + // IsHealthy returns true if message bus is operational + IsHealthy() bool + + // Close gracefully shuts down the message bus + Close() error +} + +// LocalMessageBus is a no-op implementation for single-server mode +type LocalMessageBus struct{} + +func NewLocalMessageBus() *LocalMessageBus { + return &LocalMessageBus{} +} + +func (l *LocalMessageBus) Publish(ctx context.Context, roomID string, data []byte) error { + return nil // No-op for local mode +} + +func (l *LocalMessageBus) Subscribe(ctx context.Context, roomID string) (<-chan []byte, error) { + ch := make(chan []byte) + close(ch) // Immediately closed channel + return ch, nil +} + +func (l *LocalMessageBus) Unsubscribe(ctx context.Context, roomID string) error { + return nil +} + +func (l *LocalMessageBus) SetAwareness(ctx context.Context, roomID string, clientID uint64, data []byte) error { + return nil +} + +func (l *LocalMessageBus) GetAllAwareness(ctx context.Context, roomID string) (map[uint64][]byte, error) { + return nil, nil +} + +func (l *LocalMessageBus) DeleteAwareness(ctx context.Context, roomID string, clientID uint64) error { + return nil +} +func (l *LocalMessageBus) ClearAllAwareness(ctx context.Context, roomID string) error { + return nil +} + +func (l *LocalMessageBus) IsHealthy() bool { + return true +} + +func (l *LocalMessageBus) Close() error { + return nil +} diff --git a/backend/internal/messagebus/redis.go b/backend/internal/messagebus/redis.go new file mode 100644 index 0000000..a182a1f --- /dev/null +++ b/backend/internal/messagebus/redis.go @@ -0,0 +1,392 @@ +package messagebus + +import ( + "bytes" + "context" + "fmt" + "strconv" + "sync" + "time" + + goredis "github.com/redis/go-redis/v9" + "go.uber.org/zap" +) + +// envelopeSeparator separates the serverID prefix from the payload in Redis messages. +// This allows receivers to identify and skip messages they published themselves. +var envelopeSeparator = []byte{0xFF, 0x00} + +// RedisMessageBus implements MessageBus using Redis Pub/Sub +type RedisMessageBus struct { + client *goredis.Client + logger *zap.Logger + subscriptions map[string]*subscription // roomID -> subscription + subMu sync.RWMutex + serverID string +} + +type subscription struct { + pubsub *goredis.PubSub + channel chan []byte + cancel context.CancelFunc +} + +// NewRedisMessageBus creates a new Redis-backed message bus +func NewRedisMessageBus(redisURL string, serverID string, logger *zap.Logger) (*RedisMessageBus, error) { + opts, err := goredis.ParseURL(redisURL) + if err != nil { + logger.Error("Redis URL failed", + zap.String("url", redisURL), + zap.Error(err), + ) + return nil, err + } + client := goredis.NewClient(opts) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + logger.Info("Trying to reach Redis...", zap.String("addr", opts.Addr)) + + if err := client.Ping(ctx).Err(); err != nil { + logger.Error("Redis connection Ping Failed", + zap.String("addr", opts.Addr), + zap.Error(err), + ) + _ = client.Close() + return nil, err + } + + logger.Info("Redis connected successfully", zap.String("addr", opts.Addr)) + return &RedisMessageBus{ + client: client, + logger: logger, + subscriptions: make(map[string]*subscription), + serverID: serverID, + }, nil +} + +// Publish sends a binary message to a room channel, prepending the serverID envelope +func (r *RedisMessageBus) Publish(ctx context.Context, roomID string, data []byte) error { + channel := fmt.Sprintf("room:%s:messages", roomID) + + // Prepend serverID + separator so receivers can filter self-echoes + envelope := make([]byte, 0, len(r.serverID)+len(envelopeSeparator)+len(data)) + envelope = append(envelope, []byte(r.serverID)...) + envelope = append(envelope, envelopeSeparator...) + envelope = append(envelope, data...) + + err := r.client.Publish(ctx, channel, envelope).Err() + + if err != nil { + r.logger.Error("failed to publish message", + zap.String("roomID", roomID), + zap.Int("data_len", len(data)), + zap.String("channel", channel), + zap.Error(err), + ) + return fmt.Errorf("redis publish failed: %w", err) + } + r.logger.Debug("published message successfully", + zap.String("roomID", roomID), + zap.Int("data_len", len(data)), + ) + return nil +} + +// Subscribe creates a subscription to a room channel +func (r *RedisMessageBus) Subscribe(ctx context.Context, roomID string) (<-chan []byte, error) { + r.subMu.Lock() + defer r.subMu.Unlock() + + if sub, exists := r.subscriptions[roomID]; exists { + r.logger.Debug("returning existing subscription", zap.String("roomID", roomID)) + return sub.channel, nil + } + + // Subscribe to Redis channel + channel := fmt.Sprintf("room:%s:messages", roomID) + pubsub := r.client.Subscribe(ctx, channel) + + if _, err := pubsub.Receive(ctx); err != nil { + pubsub.Close() + return nil, fmt.Errorf("failed to verify subscription: %w", err) + } + + subCtx, cancel := context.WithCancel(context.Background()) + msgChan := make(chan []byte, 256) + sub := &subscription{ + pubsub: pubsub, + channel: msgChan, + cancel: cancel, + } + r.subscriptions[roomID] = sub + + go r.forwardMessages(subCtx, roomID, sub.pubsub, msgChan) + + r.logger.Info("successfully subscribed to room", + zap.String("roomID", roomID), + zap.String("channel", channel), + ) + return msgChan, nil +} + +// forwardMessages receives from Redis PubSub and forwards to local channel +func (r *RedisMessageBus) forwardMessages(ctx context.Context, roomID string, pubsub *goredis.PubSub, msgChan chan []byte) { + defer func() { + close(msgChan) + r.logger.Info("forwarder stopped", zap.String("roomID", roomID)) + }() + + //Get the Redis channel from pubsub + ch := pubsub.Channel() + + for { + select { + case <-ctx.Done(): + r.logger.Info("stopping the channel due to context cancellation", zap.String("roomID", roomID)) + return + + case msg, ok := <-ch: + // Check if channel is closed (!ok) + if !ok { + r.logger.Warn("redis pubsub channel closed unexpectedly", zap.String("roomID", roomID)) + return + } + + // Parse envelope: serverID + separator + payload + raw := []byte(msg.Payload) + sepIdx := bytes.Index(raw, envelopeSeparator) + if sepIdx == -1 { + r.logger.Warn("received message without server envelope, skipping", + zap.String("roomID", roomID)) + continue + } + + senderID := string(raw[:sepIdx]) + payload := raw[sepIdx+len(envelopeSeparator):] + + // Skip messages published by this same server (prevent echo) + if senderID == r.serverID { + continue + } + + select { + case msgChan <- payload: + r.logger.Debug("message forwarded", + zap.String("roomID", roomID), + zap.String("from_server", senderID), + zap.Int("size", len(payload))) + default: + r.logger.Warn("message dropped: consumer too slow", + zap.String("roomID", roomID)) + } + } + } +} + +// Unsubscribe stops listening to a room +func (r *RedisMessageBus) Unsubscribe(ctx context.Context, roomID string) error { + r.subMu.Lock() + defer r.subMu.Unlock() + + // Check if subscription exists + sub, ok := r.subscriptions[roomID] + if !ok { + r.logger.Debug("unsubscribe ignored: room not found", zap.String("roomID", roomID)) + return nil + } + // Cancel the context (stops forwardMessages goroutine) + sub.cancel() + + // Close the Redis pubsub connection + if err := sub.pubsub.Close(); err != nil { + r.logger.Error("failed to close redis pubsub", + zap.String("roomID", roomID), + zap.Error(err), + ) + } + // Remove from subscriptions map + delete(r.subscriptions, roomID) + r.logger.Info("successfully unsubscribed", zap.String("roomID", roomID)) + + return nil +} + +// SetAwareness caches awareness data in Redis Hash +func (r *RedisMessageBus) SetAwareness(ctx context.Context, roomID string, clientID uint64, data []byte) error { + key := fmt.Sprintf("room:%s:awareness", roomID) + field := fmt.Sprintf("%d", clientID) + + if err := r.client.HSet(ctx, key, field, data).Err(); err != nil { + r.logger.Error("failed to set awareness data", + zap.String("roomID", roomID), + zap.Uint64("clientID", clientID), + zap.Error(err), + ) + return fmt.Errorf("hset awareness failed: %w", err) + } + + // Set expiration on the Hash (30s) + if err := r.client.Expire(ctx, key, 30*time.Second).Err(); err != nil { + r.logger.Warn("failed to set expiration on awareness key", + zap.String("key", key), + zap.Error(err), + ) + } + r.logger.Debug("awareness updated", + zap.String("roomID", roomID), + zap.Uint64("clientID", clientID), + zap.Int("data_len", len(data)), + ) + + return nil +} + +// GetAllAwareness retrieves all cached awareness for a room +func (r *RedisMessageBus) GetAllAwareness(ctx context.Context, roomID string) (map[uint64][]byte, error) { + // 1. 构建 Redis Hash key + key := fmt.Sprintf("room:%s:awareness", roomID) + + // 2. 从 Redis 获取所有字段 + // HGetAll 会返回该 Hash 下所有的 field 和 value + result, err := r.client.HGetAll(ctx, key).Result() + if err != nil { + r.logger.Error("failed to HGetAll awareness", + zap.String("roomID", roomID), + zap.Error(err), + ) + return nil, fmt.Errorf("redis hgetall failed: %w", err) + } + + // 3. 转换数据类型:map[string]string -> map[uint64][]byte + awarenessMap := make(map[uint64][]byte, len(result)) + + for field, value := range result { + // 解析 field (clientID) 为 uint64 + // 虽然提示可以用 Sscanf,但在 Go 中 strconv.ParseUint 通常更高效且稳健 + clientID, err := strconv.ParseUint(field, 10, 64) + if err != nil { + r.logger.Warn("invalid clientID format in awareness hash", + zap.String("roomID", roomID), + zap.String("field", field), + ) + continue // 跳过异常字段,保证其他数据正常显示 + } + + // 将 string 转换为 []byte + awarenessMap[clientID] = []byte(value) + } + + // 4. 记录日志 + r.logger.Debug("retrieved all awareness data", + zap.String("roomID", roomID), + zap.Int("client_count", len(awarenessMap)), + ) + + return awarenessMap, nil +} + +// DeleteAwareness removes awareness cache for a client +func (r *RedisMessageBus) DeleteAwareness(ctx context.Context, roomID string, clientID uint64) error { + key := fmt.Sprintf("room:%s:awareness", roomID) + field := fmt.Sprintf("%d", clientID) + + if err := r.client.HDel(ctx, key, field).Err(); err != nil { + r.logger.Error("failed to delete awareness data", + zap.String("roomID", roomID), + zap.Uint64("clientID", clientID), + zap.Error(err), + ) + return fmt.Errorf("delete awareness failed: %w", err) + } + return nil +} + +// IsHealthy checks Redis connectivity +func (r *RedisMessageBus) IsHealthy() bool { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + // 只有 Ping 成功且没有报错,才认为服务是健康的 + if err := r.client.Ping(ctx).Err(); err != nil { + r.logger.Warn("Redis health check failed", zap.Error(err)) + return false + } + + return true +} + +// StartHealthMonitoring runs periodic health checks +func (r *RedisMessageBus) StartHealthMonitoring(ctx context.Context, interval time.Duration, onStatusChange func(bool)) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + previouslyHealthy := true + + for { + select { + case <-ctx.Done(): + r.logger.Info("stopping health monitoring") + return + case <-ticker.C: + // 检查当前健康状态 + currentlyHealthy := r.IsHealthy() + + // 如果状态发生变化(健康 -> 亚健康,或 亚健康 -> 恢复) + if currentlyHealthy != previouslyHealthy { + r.logger.Warn("Redis health status changed", + zap.Bool("old_status", previouslyHealthy), + zap.Bool("new_status", currentlyHealthy), + ) + + // 触发外部回调逻辑 + if onStatusChange != nil { + onStatusChange(currentlyHealthy) + } + + // 更新历史状态 + previouslyHealthy = currentlyHealthy + } + } + } +} + +func (r *RedisMessageBus) Close() error { + r.subMu.Lock() + defer r.subMu.Unlock() + + r.logger.Info("gracefully shutting down message bus", zap.Int("active_subs", len(r.subscriptions))) + + // 1. 关闭所有正在运行的订阅 + for roomID, sub := range r.subscriptions { + // 停止对应的 forwardMessages 协程 + sub.cancel() + + // 关闭物理连接 + if err := sub.pubsub.Close(); err != nil { + r.logger.Error("failed to close pubsub connection", + zap.String("roomID", roomID), + zap.Error(err), + ) + } + } + + // 2. 清空 Map,释放引用以便 GC 回收 + r.subscriptions = make(map[string]*subscription) + + // 3. 关闭主 Redis 客户端连接池 + if err := r.client.Close(); err != nil { + r.logger.Error("failed to close redis client", zap.Error(err)) + return err + } + + r.logger.Info("Redis message bus closed successfully") + return nil +} +// ClearAllAwareness 彻底删除该房间的感知数据 Hash +func (r *RedisMessageBus) ClearAllAwareness(ctx context.Context, roomID string) error { + key := fmt.Sprintf("room:%s:awareness", roomID) + // 直接使用 Del 命令删除整个 Key + return r.client.Del(ctx, key).Err() +}