feat: implement Redis-based message bus and awareness management
- Added RedisMessageBus for message distribution using Redis Pub/Sub. - Introduced LocalMessageBus as a no-op implementation for single-server mode. - Created messagebus interface for abstraction of message distribution. - Implemented awareness management methods: SetAwareness, GetAllAwareness, DeleteAwareness, and ClearAllAwareness. - Added logger utility for structured logging with zap. - Refactored SniffYjsClientIDs and MakeYjsDeleteMessage functions for improved readability.
This commit is contained in:
@@ -11,7 +11,9 @@ require (
|
||||
github.com/gorilla/websocket v1.5.3
|
||||
github.com/joho/godotenv v1.5.1
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/redis/go-redis/v9 v9.17.3
|
||||
github.com/stretchr/testify v1.11.1
|
||||
go.uber.org/zap v1.27.1
|
||||
golang.org/x/oauth2 v0.34.0
|
||||
)
|
||||
|
||||
@@ -19,8 +21,10 @@ require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
github.com/bytedance/sonic v1.14.0 // indirect
|
||||
github.com/bytedance/sonic/loader v0.3.0 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.3.0 // indirect
|
||||
github.com/cloudwego/base64x v0.1.6 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.9 // indirect
|
||||
github.com/gin-contrib/sse v1.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
@@ -29,7 +33,6 @@ require (
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
|
||||
github.com/kr/text v0.2.0 // indirect
|
||||
github.com/leodido/go-urn v1.4.0 // indirect
|
||||
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
@@ -41,6 +44,7 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.3.0 // indirect
|
||||
go.uber.org/mock v0.5.0 // indirect
|
||||
go.uber.org/multierr v1.10.0 // indirect
|
||||
golang.org/x/arch v0.20.0 // indirect
|
||||
golang.org/x/crypto v0.40.0 // indirect
|
||||
golang.org/x/mod v0.25.0 // indirect
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs=
|
||||
github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c=
|
||||
github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA=
|
||||
github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0=
|
||||
github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ=
|
||||
github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA=
|
||||
github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA=
|
||||
github.com/bytedance/sonic/loader v0.3.0/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFosQU6FxH2JmUe6VI=
|
||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
|
||||
github.com/cloudwego/base64x v0.1.6 h1:t11wG9AECkCDk5fMSoxmufanudBtJ+/HemLstXDLI2M=
|
||||
github.com/cloudwego/base64x v0.1.6/go.mod h1:OFcloc187FXDaYHvrNIjxSe8ncn0OOM8gEHfghB2IPU=
|
||||
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY=
|
||||
github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok=
|
||||
github.com/gin-contrib/cors v1.7.6 h1:3gQ8GMzs1Ylpf70y8bMw4fVpycXIeX1ZemuSQIsnQQY=
|
||||
@@ -68,6 +75,8 @@ github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI=
|
||||
github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg=
|
||||
github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg=
|
||||
github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY=
|
||||
github.com/redis/go-redis/v9 v9.17.3 h1:fN29NdNrE17KttK5Ndf20buqfDZwGNgoUr9qjl1DQx4=
|
||||
github.com/redis/go-redis/v9 v9.17.3/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
@@ -83,8 +92,14 @@ github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
|
||||
github.com/ugorji/go/codec v1.3.0 h1:Qd2W2sQawAfG8XSvzwhBeoGq71zXOC/Q1E9y/wUcsUA=
|
||||
github.com/ugorji/go/codec v1.3.0/go.mod h1:pRBVtBSKl77K30Bv8R2P+cLSGaTtex6fsA2Wjqmfxj4=
|
||||
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
|
||||
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
|
||||
go.uber.org/mock v0.5.0 h1:KAMbZvZPyBPWgD14IrIQ38QCyjwpvVVV6K/bHl1IwQU=
|
||||
go.uber.org/mock v0.5.0/go.mod h1:ge71pBPLYDk7QIi1LupWxdAykm7KIEFchiOqd6z7qMM=
|
||||
go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ=
|
||||
go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
|
||||
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
|
||||
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
|
||||
golang.org/x/arch v0.20.0 h1:dx1zTU0MAE98U+TQ8BLl7XsJbgze2WnNKF/8tGp/Q6c=
|
||||
golang.org/x/arch v0.20.0/go.mod h1:bdwinDaKcfZUGpH09BB7ZmOfhalA8lQdzl62l8gGWsk=
|
||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package hub
|
||||
|
||||
import (
|
||||
"log"
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/messagebus"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
@@ -38,7 +40,7 @@ type Room struct {
|
||||
ID string
|
||||
clients map[*Client]bool
|
||||
mu sync.RWMutex
|
||||
lastAwareness []byte // 存储最新的 awareness 消息,用于新用户加入时立即同步
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
type Hub struct {
|
||||
@@ -47,13 +49,25 @@ type Hub struct {
|
||||
Register chan *Client // Exported
|
||||
Unregister chan *Client // Exported
|
||||
Broadcast chan *Message // Exported
|
||||
|
||||
//redis pub/sub config
|
||||
messagebus messagebus.MessageBus
|
||||
logger *zap.Logger
|
||||
serverID string
|
||||
fallbackMode bool
|
||||
}
|
||||
func NewHub() *Hub {
|
||||
|
||||
func NewHub(messagebus messagebus.MessageBus, serverID string, logger *zap.Logger) *Hub {
|
||||
return &Hub{
|
||||
rooms: make(map[string]*Room),
|
||||
Register: make(chan *Client, 256),
|
||||
Unregister: make(chan *Client, 256),
|
||||
Broadcast: make(chan *Message, 4096),
|
||||
// redis
|
||||
messagebus: messagebus,
|
||||
serverID: serverID,
|
||||
logger: logger,
|
||||
fallbackMode: false, // 默认 Redis 正常工作
|
||||
}
|
||||
}
|
||||
|
||||
@@ -75,37 +89,105 @@ func (h *Hub) registerClient(client *Client) {
|
||||
defer h.mu.Unlock()
|
||||
|
||||
room, exists := h.rooms[client.roomID]
|
||||
|
||||
// --- 1. 初始化房间 (仅针对该服务器上的第一个人) ---
|
||||
if !exists {
|
||||
room = &Room{
|
||||
ID: client.roomID,
|
||||
clients: make(map[*Client]bool),
|
||||
cancel: nil,
|
||||
// lastAwareness 已被我们从结构体中物理删除或不再使用
|
||||
}
|
||||
h.rooms[client.roomID] = room
|
||||
log.Printf("Created new room with ID: %s", client.roomID)
|
||||
h.logger.Info("Created new local room instance", zap.String("room_id", client.roomID))
|
||||
|
||||
// 开启跨服订阅
|
||||
if !h.fallbackMode && h.messagebus != nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
room.cancel = cancel
|
||||
|
||||
msgChan, err := h.messagebus.Subscribe(ctx, client.roomID)
|
||||
if err != nil {
|
||||
h.logger.Error("Redis Subscribe failed", zap.Error(err))
|
||||
cancel()
|
||||
room.cancel = nil
|
||||
} else {
|
||||
// 启动转发协程:确保以后别的服务器的消息能传给这台机器的人
|
||||
go h.startRoomMessageForwarding(ctx, client.roomID, msgChan)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// --- 2. 将客户端加入本地房间列表 ---
|
||||
room.mu.Lock()
|
||||
room.clients[client] = true
|
||||
// 获取现有的 awareness 数据(如果有的话)
|
||||
awarenessData := room.lastAwareness
|
||||
room.mu.Unlock()
|
||||
log.Printf("Client %s joined room %s (total clients: %d)", client.ID, client.roomID, len(room.clients))
|
||||
|
||||
// 如果房间有之前的 awareness 状态,立即发送给新用户
|
||||
// 这样新用户不需要等待其他用户的下一次广播就能看到在线用户
|
||||
if len(awarenessData) > 0 {
|
||||
select {
|
||||
case client.send <- awarenessData:
|
||||
log.Printf("📤 Sent existing awareness to new client %s", client.ID)
|
||||
default:
|
||||
log.Printf("⚠️ Failed to send awareness to new client %s (channel full)", client.ID)
|
||||
h.logger.Info("Client linked to room", zap.String("client_id", client.ID))
|
||||
|
||||
// --- 3. 核心改进:立刻同步全量状态 ---
|
||||
// 无论是不是第一个人,只要有人进来,我们就去 Redis 抓取所有人的状态发给他
|
||||
// hub/hub.go 内部的 registerClient 函数
|
||||
|
||||
// ... 之前的代码保持不变 ...
|
||||
|
||||
if !h.fallbackMode && h.messagebus != nil {
|
||||
go func(c *Client) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 1. 从 Redis 抓取
|
||||
awarenessMap, err := h.messagebus.GetAllAwareness(ctx, c.roomID)
|
||||
if err != nil {
|
||||
h.logger.Error("Redis sync failed in goroutine",
|
||||
zap.String("client_id", c.ID),
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if len(awarenessMap) == 0 {
|
||||
h.logger.Debug("No awareness data found in Redis for sync", zap.String("room_id", c.roomID))
|
||||
return
|
||||
}
|
||||
|
||||
h.logger.Info("Starting state delivery to joiner",
|
||||
zap.String("client_id", c.ID),
|
||||
zap.Int("items", len(awarenessMap)))
|
||||
|
||||
// 2. 逐条发送,带锁保护
|
||||
sentCount := 0
|
||||
for clientID, data := range awarenessMap {
|
||||
c.sendMu.Lock()
|
||||
// 🛑 核心防御:检查通道是否已被 unregisterClient 关闭
|
||||
if c.sendClosed {
|
||||
c.sendMu.Unlock()
|
||||
h.logger.Warn("Sync aborted: client channel closed while sending",
|
||||
zap.String("client_id", c.ID),
|
||||
zap.Uint64("target_yjs_id", clientID))
|
||||
return // 直接退出协程,不发了
|
||||
}
|
||||
|
||||
select {
|
||||
case c.send <- data:
|
||||
sentCount++
|
||||
default:
|
||||
// 缓冲区满了(通常是因为网络太卡),记录一条警告
|
||||
h.logger.Warn("Sync item skipped: client send buffer full",
|
||||
zap.String("client_id", c.ID),
|
||||
zap.Uint64("target_yjs_id", clientID))
|
||||
}
|
||||
c.sendMu.Unlock()
|
||||
}
|
||||
|
||||
h.logger.Info("State sync completed successfully",
|
||||
zap.String("client_id", c.ID),
|
||||
zap.Int("delivered", sentCount))
|
||||
}(client)
|
||||
}
|
||||
}
|
||||
func (h *Hub) unregisterClient(client *Client) {
|
||||
h.mu.Lock()
|
||||
// ---------------------------------------------------
|
||||
// 注意:这里不要用 defer h.mu.Unlock(),因为我们要手动控制锁
|
||||
// ---------------------------------------------------
|
||||
// 注意:这里不使用 defer,因为我们需要根据逻辑流手动释放锁,避免阻塞后续的 Redis 操作
|
||||
|
||||
room, exists := h.rooms[client.roomID]
|
||||
if !exists {
|
||||
@@ -113,32 +195,55 @@ func (h *Hub) unregisterClient(client *Client) {
|
||||
return
|
||||
}
|
||||
|
||||
room.mu.Lock() // 锁住房间
|
||||
// 1. 局部清理:从房间内移除客户端
|
||||
room.mu.Lock()
|
||||
if _, ok := room.clients[client]; ok {
|
||||
delete(room.clients, client)
|
||||
|
||||
// 关闭连接通道
|
||||
// 安全关闭客户端的发送管道
|
||||
client.sendMu.Lock()
|
||||
if !client.sendClosed {
|
||||
close(client.send)
|
||||
client.sendClosed = true
|
||||
}
|
||||
client.sendMu.Unlock()
|
||||
}
|
||||
remainingClientsCount := len(room.clients)
|
||||
room.mu.Unlock()
|
||||
|
||||
log.Printf("Client disconnected: %s", client.ID)
|
||||
h.logger.Info("Unregistered client from room",
|
||||
zap.String("client_id", client.ID),
|
||||
zap.String("room_id", client.roomID),
|
||||
zap.Int("remaining_clients", remainingClientsCount),
|
||||
)
|
||||
|
||||
// 2. 分布式清理:删除 Redis 中的感知数据 (Awareness)
|
||||
if !h.fallbackMode && h.messagebus != nil {
|
||||
go func() {
|
||||
// 使用带超时的 Context 执行删除
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
client.idsMu.Lock()
|
||||
// 遍历该客户端在本机观察到的所有 Yjs ID
|
||||
for clientID := range client.observedYjsIDs {
|
||||
err := h.messagebus.DeleteAwareness(ctx, client.roomID, clientID)
|
||||
h.logger.Info("DEBUG: IDs to cleanup",
|
||||
zap.String("client_id", client.ID),
|
||||
zap.Any("ids", client.observedYjsIDs))
|
||||
if err != nil {
|
||||
h.logger.Warn("Failed to delete awareness from Redis",
|
||||
zap.Uint64("yjs_id", clientID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
client.idsMu.Unlock()
|
||||
}()
|
||||
}
|
||||
|
||||
// 检查房间是否还有其他人
|
||||
remainingClientsCount := len(room.clients)
|
||||
room.mu.Unlock() // 解锁房间 (我们已经删完人了)
|
||||
|
||||
// ---------------------------------------------------
|
||||
// [新增] 僵尸用户清理逻辑 (核心修改)
|
||||
// ---------------------------------------------------
|
||||
|
||||
// 只有当房间里还有其他人时,才需要广播通知
|
||||
// 3. 协作清理:发送“僵尸删除”消息给其他幸存者
|
||||
if remainingClientsCount > 0 {
|
||||
// 1. 从 client 的小本本里取出它用过的 Yjs ID 和对应的 clock
|
||||
client.idsMu.Lock()
|
||||
clientClocks := make(map[uint64]uint64, len(client.observedYjsIDs))
|
||||
for id, clock := range client.observedYjsIDs {
|
||||
@@ -146,49 +251,57 @@ func (h *Hub) unregisterClient(client *Client) {
|
||||
}
|
||||
client.idsMu.Unlock()
|
||||
|
||||
// 2. 如果有记录到的 ID,就伪造删除消息 (使用 clock+1)
|
||||
if len(clientClocks) > 0 {
|
||||
deleteMsg := MakeYjsDeleteMessage(clientClocks) // 调用工具函数,传入 clientID -> clock map
|
||||
|
||||
log.Printf("🧹 Notifying others to remove Yjs IDs with clocks: %v", clientClocks)
|
||||
|
||||
// 3. 广播给房间里的幸存者
|
||||
// 构造一个消息对象
|
||||
// 构造 Yjs 协议格式的删除消息
|
||||
deleteMsg := MakeYjsDeleteMessage(clientClocks)
|
||||
msg := &Message{
|
||||
RoomID: client.roomID,
|
||||
Data: deleteMsg,
|
||||
sender: nil, // sender 设为 nil,表示系统消息
|
||||
sender: nil, // 系统发送
|
||||
}
|
||||
|
||||
// !!特别注意!!
|
||||
// 不要在这里直接调用 h.broadcastMessage(msg),因为那会尝试重新获取 h.mu 锁导致死锁
|
||||
// 我们直接把它扔到 Channel 里,让 Run() 去处理
|
||||
// 必须在一个非阻塞的 goroutine 里发,或者确保 channel 有缓冲
|
||||
|
||||
// 异步发送到广播通道,避免在持有 h.mu 时发生死锁
|
||||
go func() {
|
||||
// 使用 select 尝试发送,但如果满了,我们要稍微等一下,而不是直接丢弃
|
||||
// 因为这是“清理僵尸”的关键消息,丢了就会出 Bug
|
||||
select {
|
||||
case h.Broadcast <- msg:
|
||||
// 发送成功
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
// 如果 500ms 还没塞进去,那说明系统真的挂了,只能丢弃并打印错误
|
||||
log.Printf("❌ Critical: Failed to broadcast cleanup message (Channel blocked)")
|
||||
h.logger.Error("Critical: Failed to broadcast cleanup message (channel blocked)")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------
|
||||
// 结束清理逻辑
|
||||
// ---------------------------------------------------
|
||||
|
||||
// 4. 房间清理:如果是最后一个人,彻底销毁房间资源
|
||||
if remainingClientsCount == 0 {
|
||||
delete(h.rooms, client.roomID)
|
||||
log.Printf("Room destroyed: %s", client.roomID)
|
||||
h.logger.Info("Room is empty, performing deep cleanup", zap.String("room_id", client.roomID))
|
||||
|
||||
// A. 停止转发协程
|
||||
if room.cancel != nil {
|
||||
room.cancel()
|
||||
}
|
||||
|
||||
h.mu.Unlock() // 最后解锁 Hub
|
||||
// B. 分布式彻底清理 (关键改动!)
|
||||
if !h.fallbackMode && h.messagebus != nil {
|
||||
go func(rID string) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// ✨ 1. 直接删除整个 Redis Hash 表,不留任何死角
|
||||
if err := h.messagebus.ClearAllAwareness(ctx, rID); err != nil {
|
||||
h.logger.Warn("Failed to clear total awareness from Redis", zap.Error(err))
|
||||
}
|
||||
|
||||
// 2. 取消订阅
|
||||
if err := h.messagebus.Unsubscribe(ctx, rID); err != nil {
|
||||
h.logger.Warn("Failed to unsubscribe", zap.Error(err))
|
||||
}
|
||||
}(client.roomID)
|
||||
}
|
||||
// C. 从内存中移除
|
||||
delete(h.rooms, client.roomID)
|
||||
}
|
||||
|
||||
h.mu.Unlock() // 手动释放 Hub 锁
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -202,55 +315,197 @@ func (h *Hub) broadcastMessage(message *Message) {
|
||||
h.mu.RLock()
|
||||
room, exists := h.rooms[message.RoomID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
log.Printf("Room %s does not exist for broadcasting", message.RoomID)
|
||||
// 如果房间不存在,没必要继续
|
||||
return
|
||||
}
|
||||
|
||||
// 如果是 awareness 消息 (type=1),保存它供新用户使用
|
||||
if len(message.Data) > 0 && message.Data[0] == 1 {
|
||||
room.mu.Lock()
|
||||
room.lastAwareness = make([]byte, len(message.Data))
|
||||
copy(room.lastAwareness, message.Data)
|
||||
room.mu.Unlock()
|
||||
// 1. 处理 Awareness 缓存 (Type 1)
|
||||
// if len(message.Data) > 0 && message.Data[0] == 1 {
|
||||
// room.mu.Lock()
|
||||
// room.lastAwareness = make([]byte, len(message.Data))
|
||||
// copy(room.lastAwareness, message.Data)
|
||||
// room.mu.Unlock()
|
||||
// }
|
||||
|
||||
// 2. 本地广播
|
||||
h.broadcastToLocalClients(room, message.Data, message.sender)
|
||||
|
||||
// 只有本地客户端发出的消息 (sender != nil) 才推送到 Redis
|
||||
if message.sender != nil && !h.fallbackMode && h.messagebus != nil {
|
||||
go func() { // 建议异步 Publish,不阻塞 Hub 的主循环
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := h.messagebus.Publish(ctx, message.RoomID, message.Data)
|
||||
if err != nil {
|
||||
h.logger.Error("MessageBus publish failed",
|
||||
zap.String("room_id", message.RoomID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) broadcastToLocalClients(room *Room, data []byte, sender *Client) {
|
||||
room.mu.RLock()
|
||||
defer room.mu.RUnlock()
|
||||
|
||||
// 2. 遍历房间内所有本地连接的客户端
|
||||
for client := range room.clients {
|
||||
if client != message.sender {
|
||||
// 3. 排除发送者(如果是从 Redis 来的消息,sender 通常为 nil)
|
||||
if client != sender {
|
||||
select {
|
||||
case client.send <- message.Data:
|
||||
// Success - reset failure count
|
||||
case client.send <- data:
|
||||
// 发送成功:重置该客户端的失败计数
|
||||
client.failureMu.Lock()
|
||||
client.failureCount = 0
|
||||
client.failureMu.Unlock()
|
||||
|
||||
default:
|
||||
// Failed - increment failure count
|
||||
client.failureMu.Lock()
|
||||
client.failureCount++
|
||||
currentFailures := client.failureCount
|
||||
client.failureMu.Unlock()
|
||||
|
||||
log.Printf("Failed to send to client %s (channel full, failures: %d/%d)",
|
||||
client.ID, currentFailures, maxSendFailures)
|
||||
client.handleSendFailure()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
func (h *Hub) startRoomMessageForwarding(ctx context.Context, roomID string, msgChan <-chan []byte) {
|
||||
h.logger.Info("Starting message forwarding from Redis to room",
|
||||
zap.String("room_id", roomID),
|
||||
zap.String("server_id", h.serverID),
|
||||
)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
h.logger.Info("Stopping Redis message forwarding", zap.String("room_id", roomID))
|
||||
return
|
||||
|
||||
case data, ok := <-msgChan:
|
||||
if !ok {
|
||||
h.logger.Warn("Redis message channel closed for room", zap.String("room_id", roomID))
|
||||
return
|
||||
}
|
||||
|
||||
// 1. 日志记录:记录消息类型和长度
|
||||
msgType := byte(0)
|
||||
if len(data) > 0 {
|
||||
msgType = data[0]
|
||||
}
|
||||
h.logger.Debug("Received message from Redis",
|
||||
zap.String("room_id", roomID),
|
||||
zap.Int("bytes", len(data)),
|
||||
zap.Uint8("msg_type", msgType),
|
||||
)
|
||||
|
||||
// 2. 获取本地房间对象
|
||||
h.mu.RLock()
|
||||
room, exists := h.rooms[roomID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
// 这种情况常见于:Redis 消息飞过来时,本地最后一个用户刚好断开删除了房间
|
||||
h.logger.Warn("Received Redis message but room does not exist locally",
|
||||
zap.String("room_id", roomID),
|
||||
)
|
||||
continue
|
||||
}
|
||||
|
||||
// 3. 广播给本地所有客户端
|
||||
// sender=nil 确保每个人(包括原本发送这条消息的人,如果他在本台机器上)都能收到
|
||||
h.broadcastToLocalClients(room, data, nil)
|
||||
|
||||
// 4. 如果是感知数据 (Awareness),更新房间缓存
|
||||
// 这样后来加入的用户能立即看到其他人的光标
|
||||
// if len(data) > 0 && data[0] == 1 {
|
||||
// room.mu.Lock()
|
||||
// room.lastAwareness = make([]byte, len(data))
|
||||
// copy(room.lastAwareness, data)
|
||||
// room.mu.Unlock()
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助方法:处理发送失败逻辑,保持代码整洁 (DRY 原则)
|
||||
func (c *Client) handleSendFailure() {
|
||||
c.failureMu.Lock()
|
||||
c.failureCount++
|
||||
currentFailures := c.failureCount
|
||||
c.failureMu.Unlock()
|
||||
|
||||
// 直接通过 c.hub 访问日志,代码非常自然
|
||||
c.hub.logger.Warn("Failed to send message to client",
|
||||
zap.String("clientID", c.ID),
|
||||
zap.String("roomID", c.roomID),
|
||||
zap.Int("failures", currentFailures),
|
||||
zap.Int("max", maxSendFailures),
|
||||
)
|
||||
|
||||
// Disconnect if threshold exceeded
|
||||
if currentFailures >= maxSendFailures {
|
||||
log.Printf("Client %s exceeded max send failures, disconnecting", client.ID)
|
||||
go func(c *Client) {
|
||||
c.hub.logger.Error("Client exceeded max failures, disconnecting",
|
||||
zap.String("clientID", c.ID))
|
||||
|
||||
// 这里的异步处理很正确,防止阻塞当前的广播循环
|
||||
go func() {
|
||||
c.unregister()
|
||||
c.Conn.Close()
|
||||
}(client)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
func (h *Hub) SetFallbackMode(fallback bool) {
|
||||
h.mu.Lock()
|
||||
// 1. 检查状态是否有变化(幂等性)
|
||||
if h.fallbackMode == fallback {
|
||||
h.mu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// 2. 更新状态并记录日志
|
||||
h.fallbackMode = fallback
|
||||
if fallback {
|
||||
h.logger.Warn("Hub entering FALLBACK MODE (local-only communication)")
|
||||
} else {
|
||||
h.logger.Info("Hub restored to DISTRIBUTED MODE (cross-server sync enabled)")
|
||||
}
|
||||
|
||||
// 3. 如果是从备用模式恢复 (fallback=false),需要重新激活所有房间的 Redis 订阅
|
||||
if !fallback && h.messagebus != nil {
|
||||
h.logger.Info("Recovering Redis subscriptions for existing rooms", zap.Int("room_count", len(h.rooms)))
|
||||
|
||||
for roomID, room := range h.rooms {
|
||||
// 先清理旧的订阅句柄(如果有残留)
|
||||
if room.cancel != nil {
|
||||
room.cancel()
|
||||
room.cancel = nil
|
||||
}
|
||||
|
||||
// 为每个房间重新开启订阅
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
room.cancel = cancel
|
||||
|
||||
// 尝试重新订阅
|
||||
msgChan, err := h.messagebus.Subscribe(ctx, roomID)
|
||||
if err != nil {
|
||||
h.logger.Error("Failed to re-subscribe during recovery",
|
||||
zap.String("room_id", roomID),
|
||||
zap.Error(err),
|
||||
)
|
||||
cancel()
|
||||
room.cancel = nil
|
||||
continue
|
||||
}
|
||||
|
||||
// 重新启动转发协程
|
||||
go h.startRoomMessageForwarding(ctx, roomID, msgChan)
|
||||
|
||||
h.logger.Debug("Successfully re-synced room with Redis", zap.String("room_id", roomID))
|
||||
}
|
||||
}
|
||||
h.mu.Unlock()
|
||||
}
|
||||
func (c *Client) ReadPump() {
|
||||
c.Conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
c.Conn.SetPongHandler(func(string) error {
|
||||
@@ -266,69 +521,70 @@ func (c *Client) ReadPump() {
|
||||
messageType, message, err := c.Conn.ReadMessage()
|
||||
if err != nil {
|
||||
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
||||
log.Printf("error: %v", err)
|
||||
c.hub.logger.Warn("Unexpected WebSocket close",
|
||||
zap.String("client_id", c.ID),
|
||||
zap.Error(err))
|
||||
}
|
||||
break
|
||||
}
|
||||
if len(message) > 0 && message[0] == 1 {
|
||||
log.Printf("DEBUG: 收到 Awareness (光标) 消息 from User %s Permission: %s", c.ID, c.Permission)
|
||||
}
|
||||
// ==========================================================
|
||||
// 1. 偷听逻辑 (Sniff) - 必须放在转发之前!
|
||||
// ==========================================================
|
||||
|
||||
// 1. Sniff Yjs client IDs from awareness messages (before broadcast)
|
||||
if messageType == websocket.BinaryMessage && len(message) > 0 && message[0] == 1 {
|
||||
clockMap := SniffYjsClientIDs(message)
|
||||
if len(clockMap) > 0 {
|
||||
c.idsMu.Lock()
|
||||
for id, clock := range clockMap {
|
||||
if clock > c.observedYjsIDs[id] {
|
||||
log.Printf("🕵️ [Sniff] Client %s uses YjsID: %d (clock: %d)", c.ID, id, clock)
|
||||
c.hub.logger.Debug("Sniffed Yjs client ID",
|
||||
zap.String("client_id", c.ID),
|
||||
zap.Uint64("yjs_id", id),
|
||||
zap.Uint64("clock", clock))
|
||||
c.observedYjsIDs[id] = clock
|
||||
}
|
||||
}
|
||||
c.idsMu.Unlock()
|
||||
|
||||
// Cache awareness in Redis for cross-server sync
|
||||
if !c.hub.fallbackMode && c.hub.messagebus != nil {
|
||||
go func(cm map[uint64]uint64, msg []byte) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
for clientID := range cm {
|
||||
if err := c.hub.messagebus.SetAwareness(ctx, c.roomID, clientID, msg); err != nil {
|
||||
c.hub.logger.Warn("Failed to cache awareness in Redis",
|
||||
zap.Uint64("yjs_id", clientID),
|
||||
zap.Error(err))
|
||||
}
|
||||
}
|
||||
}(clockMap, message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ==========================================================
|
||||
// 2. 权限检查 - 只有编辑权限的用户才能广播消息
|
||||
// ==========================================================
|
||||
// ==========================================================
|
||||
// 2. 权限检查 - 精细化拦截 (Fine-grained Permission Check)
|
||||
// ==========================================================
|
||||
// 2. Permission check - block write operations from view-only users
|
||||
if c.Permission == "view" {
|
||||
// Yjs Protocol:
|
||||
// message[0]: MessageType (0=Sync, 1=Awareness)
|
||||
// message[1]: SyncMessageType (0=Step1, 1=Step2, 2=Update)
|
||||
|
||||
// 只有当消息是 "Sync Update" (修改文档) 时,才拦截
|
||||
isSyncMessage := len(message) > 0 && message[0] == 0
|
||||
isUpdateOp := len(message) > 1 && message[1] == 2
|
||||
|
||||
if isSyncMessage && isUpdateOp {
|
||||
log.Printf("🛡️ [Security] Blocked unauthorized WRITE from view-only user: %s", c.ID)
|
||||
continue // ❌ 拦截修改
|
||||
c.hub.logger.Warn("Blocked write from view-only user",
|
||||
zap.String("client_id", c.ID))
|
||||
continue
|
||||
}
|
||||
// Allow: Awareness (type=1), SyncStep1/2 (type=0, subtype=0/1)
|
||||
}
|
||||
|
||||
// ✅ 放行:
|
||||
// 1. Awareness (光标): message[0] == 1
|
||||
// 2. SyncStep1/2 (握手加载文档): message[1] == 0 or 1
|
||||
}
|
||||
|
||||
// ==========================================================
|
||||
// 3. 转发逻辑 (Broadcast) - 恢复协作功能
|
||||
// ==========================================================
|
||||
// 3. Broadcast to room
|
||||
if messageType == websocket.BinaryMessage {
|
||||
// 注意:这里要检查 channel 是否已满,避免阻塞导致 ReadPump 卡死
|
||||
select {
|
||||
case c.hub.Broadcast <- &Message{
|
||||
RoomID: c.roomID,
|
||||
Data: message,
|
||||
sender: c,
|
||||
}:
|
||||
// 发送成功
|
||||
default:
|
||||
log.Printf("⚠️ Hub broadcast channel is full, dropping message from %s", c.ID)
|
||||
c.hub.logger.Warn("Hub broadcast channel full, dropping message",
|
||||
zap.String("client_id", c.ID))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -338,7 +594,7 @@ func (c *Client) WritePump() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.unregister() // NEW: Now WritePump also unregisters
|
||||
c.unregister()
|
||||
c.Conn.Close()
|
||||
}()
|
||||
|
||||
@@ -347,28 +603,29 @@ func (c *Client) WritePump() {
|
||||
case message, ok := <-c.send:
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if !ok {
|
||||
// Hub closed the channel
|
||||
c.Conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
err := c.Conn.WriteMessage(websocket.BinaryMessage, message)
|
||||
if err != nil {
|
||||
log.Printf("Error writing message to client %s: %v", c.ID, err)
|
||||
if err := c.Conn.WriteMessage(websocket.BinaryMessage, message); err != nil {
|
||||
c.hub.logger.Warn("Error writing message to client",
|
||||
zap.String("client_id", c.ID),
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
case <-ticker.C:
|
||||
c.Conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := c.Conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
log.Printf("Ping failed for client %s: %v", c.ID, err)
|
||||
c.hub.logger.Debug("Ping failed for client",
|
||||
zap.String("client_id", c.ID),
|
||||
zap.Error(err))
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
func NewClient(id string, userID *uuid.UUID, userName string, userAvatar *string, permission string, conn *websocket.Conn, hub *Hub, roomID string) *Client {
|
||||
return &Client{
|
||||
ID: id,
|
||||
|
||||
@@ -18,35 +18,49 @@ func SniffYjsClientIDs(data []byte) map[uint64]uint64 {
|
||||
|
||||
// 读取总长度 (跳过)
|
||||
_, n := binary.Uvarint(data[offset:])
|
||||
if n <= 0 { return nil }
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
offset += n
|
||||
|
||||
// 读取 Count (包含几个客户端的信息)
|
||||
count, n := binary.Uvarint(data[offset:])
|
||||
if n <= 0 { return nil }
|
||||
if n <= 0 {
|
||||
return nil
|
||||
}
|
||||
offset += n
|
||||
|
||||
// 循环读取每个客户端的信息
|
||||
for i := 0; i < int(count); i++ {
|
||||
if offset >= len(data) { break }
|
||||
if offset >= len(data) {
|
||||
break
|
||||
}
|
||||
|
||||
// 1. 读取 ClientID
|
||||
id, n := binary.Uvarint(data[offset:])
|
||||
if n <= 0 { break }
|
||||
if n <= 0 {
|
||||
break
|
||||
}
|
||||
offset += n
|
||||
|
||||
// 2. 读取 Clock (现在我们需要保存它!)
|
||||
clock, n := binary.Uvarint(data[offset:])
|
||||
if n <= 0 { break }
|
||||
if n <= 0 {
|
||||
break
|
||||
}
|
||||
offset += n
|
||||
|
||||
// 保存 clientID -> clock
|
||||
result[id] = clock
|
||||
|
||||
// 3. 跳过 JSON String
|
||||
if offset >= len(data) { break }
|
||||
if offset >= len(data) {
|
||||
break
|
||||
}
|
||||
strLen, n := binary.Uvarint(data[offset:])
|
||||
if n <= 0 { break }
|
||||
if n <= 0 {
|
||||
break
|
||||
}
|
||||
offset += n
|
||||
|
||||
// 跳过具体字符串内容
|
||||
|
||||
41
backend/internal/logger/logger.go
Normal file
41
backend/internal/logger/logger.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// NewLogger creates a production-grade logger with appropriate configuration
|
||||
func NewLogger(isDevelopment bool) (*zap.Logger, error) {
|
||||
var config zap.Config
|
||||
|
||||
if isDevelopment {
|
||||
config = zap.NewDevelopmentConfig()
|
||||
config.EncoderConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||
} else {
|
||||
config = zap.NewProductionConfig()
|
||||
config.EncoderConfig.TimeKey = "timestamp"
|
||||
config.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
}
|
||||
|
||||
// Allow DEBUG level in development
|
||||
if isDevelopment {
|
||||
config.Level = zap.NewAtomicLevelAt(zapcore.DebugLevel)
|
||||
}
|
||||
|
||||
logger, err := config.Build()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return logger, nil
|
||||
}
|
||||
|
||||
// NewLoggerFromEnv creates logger based on environment
|
||||
func NewLoggerFromEnv() (*zap.Logger, error) {
|
||||
env := os.Getenv("ENVIRONMENT")
|
||||
isDev := env == "" || env == "development" || env == "dev"
|
||||
return NewLogger(isDev)
|
||||
}
|
||||
80
backend/internal/messagebus/interface.go
Normal file
80
backend/internal/messagebus/interface.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package messagebus
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// MessageBus abstracts message distribution across server instances
|
||||
type MessageBus interface {
|
||||
// Publish sends a message to a specific room channel
|
||||
// data must be preserved as-is (binary safe)
|
||||
Publish(ctx context.Context, roomID string, data []byte) error
|
||||
|
||||
// Subscribe listens to messages for a specific room
|
||||
// Returns a channel that receives binary messages
|
||||
Subscribe(ctx context.Context, roomID string) (<-chan []byte, error)
|
||||
|
||||
// Unsubscribe stops listening to a room
|
||||
Unsubscribe(ctx context.Context, roomID string) error
|
||||
|
||||
// SetAwareness caches awareness data for a client in a room
|
||||
SetAwareness(ctx context.Context, roomID string, clientID uint64, data []byte) error
|
||||
|
||||
// GetAllAwareness retrieves all cached awareness for a room
|
||||
GetAllAwareness(ctx context.Context, roomID string) (map[uint64][]byte, error)
|
||||
|
||||
// DeleteAwareness removes awareness cache for a client
|
||||
DeleteAwareness(ctx context.Context, roomID string, clientID uint64) error
|
||||
|
||||
ClearAllAwareness(ctx context.Context, roomID string) error
|
||||
|
||||
// IsHealthy returns true if message bus is operational
|
||||
IsHealthy() bool
|
||||
|
||||
// Close gracefully shuts down the message bus
|
||||
Close() error
|
||||
}
|
||||
|
||||
// LocalMessageBus is a no-op implementation for single-server mode
|
||||
type LocalMessageBus struct{}
|
||||
|
||||
func NewLocalMessageBus() *LocalMessageBus {
|
||||
return &LocalMessageBus{}
|
||||
}
|
||||
|
||||
func (l *LocalMessageBus) Publish(ctx context.Context, roomID string, data []byte) error {
|
||||
return nil // No-op for local mode
|
||||
}
|
||||
|
||||
func (l *LocalMessageBus) Subscribe(ctx context.Context, roomID string) (<-chan []byte, error) {
|
||||
ch := make(chan []byte)
|
||||
close(ch) // Immediately closed channel
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (l *LocalMessageBus) Unsubscribe(ctx context.Context, roomID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalMessageBus) SetAwareness(ctx context.Context, roomID string, clientID uint64, data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalMessageBus) GetAllAwareness(ctx context.Context, roomID string) (map[uint64][]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (l *LocalMessageBus) DeleteAwareness(ctx context.Context, roomID string, clientID uint64) error {
|
||||
return nil
|
||||
}
|
||||
func (l *LocalMessageBus) ClearAllAwareness(ctx context.Context, roomID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LocalMessageBus) IsHealthy() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *LocalMessageBus) Close() error {
|
||||
return nil
|
||||
}
|
||||
392
backend/internal/messagebus/redis.go
Normal file
392
backend/internal/messagebus/redis.go
Normal file
@@ -0,0 +1,392 @@
|
||||
package messagebus
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
goredis "github.com/redis/go-redis/v9"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// envelopeSeparator separates the serverID prefix from the payload in Redis messages.
|
||||
// This allows receivers to identify and skip messages they published themselves.
|
||||
var envelopeSeparator = []byte{0xFF, 0x00}
|
||||
|
||||
// RedisMessageBus implements MessageBus using Redis Pub/Sub
|
||||
type RedisMessageBus struct {
|
||||
client *goredis.Client
|
||||
logger *zap.Logger
|
||||
subscriptions map[string]*subscription // roomID -> subscription
|
||||
subMu sync.RWMutex
|
||||
serverID string
|
||||
}
|
||||
|
||||
type subscription struct {
|
||||
pubsub *goredis.PubSub
|
||||
channel chan []byte
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewRedisMessageBus creates a new Redis-backed message bus
|
||||
func NewRedisMessageBus(redisURL string, serverID string, logger *zap.Logger) (*RedisMessageBus, error) {
|
||||
opts, err := goredis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
logger.Error("Redis URL failed",
|
||||
zap.String("url", redisURL),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, err
|
||||
}
|
||||
client := goredis.NewClient(opts)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
|
||||
logger.Info("Trying to reach Redis...", zap.String("addr", opts.Addr))
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
logger.Error("Redis connection Ping Failed",
|
||||
zap.String("addr", opts.Addr),
|
||||
zap.Error(err),
|
||||
)
|
||||
_ = client.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
logger.Info("Redis connected successfully", zap.String("addr", opts.Addr))
|
||||
return &RedisMessageBus{
|
||||
client: client,
|
||||
logger: logger,
|
||||
subscriptions: make(map[string]*subscription),
|
||||
serverID: serverID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Publish sends a binary message to a room channel, prepending the serverID envelope
|
||||
func (r *RedisMessageBus) Publish(ctx context.Context, roomID string, data []byte) error {
|
||||
channel := fmt.Sprintf("room:%s:messages", roomID)
|
||||
|
||||
// Prepend serverID + separator so receivers can filter self-echoes
|
||||
envelope := make([]byte, 0, len(r.serverID)+len(envelopeSeparator)+len(data))
|
||||
envelope = append(envelope, []byte(r.serverID)...)
|
||||
envelope = append(envelope, envelopeSeparator...)
|
||||
envelope = append(envelope, data...)
|
||||
|
||||
err := r.client.Publish(ctx, channel, envelope).Err()
|
||||
|
||||
if err != nil {
|
||||
r.logger.Error("failed to publish message",
|
||||
zap.String("roomID", roomID),
|
||||
zap.Int("data_len", len(data)),
|
||||
zap.String("channel", channel),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("redis publish failed: %w", err)
|
||||
}
|
||||
r.logger.Debug("published message successfully",
|
||||
zap.String("roomID", roomID),
|
||||
zap.Int("data_len", len(data)),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Subscribe creates a subscription to a room channel
|
||||
func (r *RedisMessageBus) Subscribe(ctx context.Context, roomID string) (<-chan []byte, error) {
|
||||
r.subMu.Lock()
|
||||
defer r.subMu.Unlock()
|
||||
|
||||
if sub, exists := r.subscriptions[roomID]; exists {
|
||||
r.logger.Debug("returning existing subscription", zap.String("roomID", roomID))
|
||||
return sub.channel, nil
|
||||
}
|
||||
|
||||
// Subscribe to Redis channel
|
||||
channel := fmt.Sprintf("room:%s:messages", roomID)
|
||||
pubsub := r.client.Subscribe(ctx, channel)
|
||||
|
||||
if _, err := pubsub.Receive(ctx); err != nil {
|
||||
pubsub.Close()
|
||||
return nil, fmt.Errorf("failed to verify subscription: %w", err)
|
||||
}
|
||||
|
||||
subCtx, cancel := context.WithCancel(context.Background())
|
||||
msgChan := make(chan []byte, 256)
|
||||
sub := &subscription{
|
||||
pubsub: pubsub,
|
||||
channel: msgChan,
|
||||
cancel: cancel,
|
||||
}
|
||||
r.subscriptions[roomID] = sub
|
||||
|
||||
go r.forwardMessages(subCtx, roomID, sub.pubsub, msgChan)
|
||||
|
||||
r.logger.Info("successfully subscribed to room",
|
||||
zap.String("roomID", roomID),
|
||||
zap.String("channel", channel),
|
||||
)
|
||||
return msgChan, nil
|
||||
}
|
||||
|
||||
// forwardMessages receives from Redis PubSub and forwards to local channel
|
||||
func (r *RedisMessageBus) forwardMessages(ctx context.Context, roomID string, pubsub *goredis.PubSub, msgChan chan []byte) {
|
||||
defer func() {
|
||||
close(msgChan)
|
||||
r.logger.Info("forwarder stopped", zap.String("roomID", roomID))
|
||||
}()
|
||||
|
||||
//Get the Redis channel from pubsub
|
||||
ch := pubsub.Channel()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
r.logger.Info("stopping the channel due to context cancellation", zap.String("roomID", roomID))
|
||||
return
|
||||
|
||||
case msg, ok := <-ch:
|
||||
// Check if channel is closed (!ok)
|
||||
if !ok {
|
||||
r.logger.Warn("redis pubsub channel closed unexpectedly", zap.String("roomID", roomID))
|
||||
return
|
||||
}
|
||||
|
||||
// Parse envelope: serverID + separator + payload
|
||||
raw := []byte(msg.Payload)
|
||||
sepIdx := bytes.Index(raw, envelopeSeparator)
|
||||
if sepIdx == -1 {
|
||||
r.logger.Warn("received message without server envelope, skipping",
|
||||
zap.String("roomID", roomID))
|
||||
continue
|
||||
}
|
||||
|
||||
senderID := string(raw[:sepIdx])
|
||||
payload := raw[sepIdx+len(envelopeSeparator):]
|
||||
|
||||
// Skip messages published by this same server (prevent echo)
|
||||
if senderID == r.serverID {
|
||||
continue
|
||||
}
|
||||
|
||||
select {
|
||||
case msgChan <- payload:
|
||||
r.logger.Debug("message forwarded",
|
||||
zap.String("roomID", roomID),
|
||||
zap.String("from_server", senderID),
|
||||
zap.Int("size", len(payload)))
|
||||
default:
|
||||
r.logger.Warn("message dropped: consumer too slow",
|
||||
zap.String("roomID", roomID))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Unsubscribe stops listening to a room
|
||||
func (r *RedisMessageBus) Unsubscribe(ctx context.Context, roomID string) error {
|
||||
r.subMu.Lock()
|
||||
defer r.subMu.Unlock()
|
||||
|
||||
// Check if subscription exists
|
||||
sub, ok := r.subscriptions[roomID]
|
||||
if !ok {
|
||||
r.logger.Debug("unsubscribe ignored: room not found", zap.String("roomID", roomID))
|
||||
return nil
|
||||
}
|
||||
// Cancel the context (stops forwardMessages goroutine)
|
||||
sub.cancel()
|
||||
|
||||
// Close the Redis pubsub connection
|
||||
if err := sub.pubsub.Close(); err != nil {
|
||||
r.logger.Error("failed to close redis pubsub",
|
||||
zap.String("roomID", roomID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
// Remove from subscriptions map
|
||||
delete(r.subscriptions, roomID)
|
||||
r.logger.Info("successfully unsubscribed", zap.String("roomID", roomID))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetAwareness caches awareness data in Redis Hash
|
||||
func (r *RedisMessageBus) SetAwareness(ctx context.Context, roomID string, clientID uint64, data []byte) error {
|
||||
key := fmt.Sprintf("room:%s:awareness", roomID)
|
||||
field := fmt.Sprintf("%d", clientID)
|
||||
|
||||
if err := r.client.HSet(ctx, key, field, data).Err(); err != nil {
|
||||
r.logger.Error("failed to set awareness data",
|
||||
zap.String("roomID", roomID),
|
||||
zap.Uint64("clientID", clientID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("hset awareness failed: %w", err)
|
||||
}
|
||||
|
||||
// Set expiration on the Hash (30s)
|
||||
if err := r.client.Expire(ctx, key, 30*time.Second).Err(); err != nil {
|
||||
r.logger.Warn("failed to set expiration on awareness key",
|
||||
zap.String("key", key),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
r.logger.Debug("awareness updated",
|
||||
zap.String("roomID", roomID),
|
||||
zap.Uint64("clientID", clientID),
|
||||
zap.Int("data_len", len(data)),
|
||||
)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllAwareness retrieves all cached awareness for a room
|
||||
func (r *RedisMessageBus) GetAllAwareness(ctx context.Context, roomID string) (map[uint64][]byte, error) {
|
||||
// 1. 构建 Redis Hash key
|
||||
key := fmt.Sprintf("room:%s:awareness", roomID)
|
||||
|
||||
// 2. 从 Redis 获取所有字段
|
||||
// HGetAll 会返回该 Hash 下所有的 field 和 value
|
||||
result, err := r.client.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
r.logger.Error("failed to HGetAll awareness",
|
||||
zap.String("roomID", roomID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return nil, fmt.Errorf("redis hgetall failed: %w", err)
|
||||
}
|
||||
|
||||
// 3. 转换数据类型:map[string]string -> map[uint64][]byte
|
||||
awarenessMap := make(map[uint64][]byte, len(result))
|
||||
|
||||
for field, value := range result {
|
||||
// 解析 field (clientID) 为 uint64
|
||||
// 虽然提示可以用 Sscanf,但在 Go 中 strconv.ParseUint 通常更高效且稳健
|
||||
clientID, err := strconv.ParseUint(field, 10, 64)
|
||||
if err != nil {
|
||||
r.logger.Warn("invalid clientID format in awareness hash",
|
||||
zap.String("roomID", roomID),
|
||||
zap.String("field", field),
|
||||
)
|
||||
continue // 跳过异常字段,保证其他数据正常显示
|
||||
}
|
||||
|
||||
// 将 string 转换为 []byte
|
||||
awarenessMap[clientID] = []byte(value)
|
||||
}
|
||||
|
||||
// 4. 记录日志
|
||||
r.logger.Debug("retrieved all awareness data",
|
||||
zap.String("roomID", roomID),
|
||||
zap.Int("client_count", len(awarenessMap)),
|
||||
)
|
||||
|
||||
return awarenessMap, nil
|
||||
}
|
||||
|
||||
// DeleteAwareness removes awareness cache for a client
|
||||
func (r *RedisMessageBus) DeleteAwareness(ctx context.Context, roomID string, clientID uint64) error {
|
||||
key := fmt.Sprintf("room:%s:awareness", roomID)
|
||||
field := fmt.Sprintf("%d", clientID)
|
||||
|
||||
if err := r.client.HDel(ctx, key, field).Err(); err != nil {
|
||||
r.logger.Error("failed to delete awareness data",
|
||||
zap.String("roomID", roomID),
|
||||
zap.Uint64("clientID", clientID),
|
||||
zap.Error(err),
|
||||
)
|
||||
return fmt.Errorf("delete awareness failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsHealthy checks Redis connectivity
|
||||
func (r *RedisMessageBus) IsHealthy() bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 只有 Ping 成功且没有报错,才认为服务是健康的
|
||||
if err := r.client.Ping(ctx).Err(); err != nil {
|
||||
r.logger.Warn("Redis health check failed", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// StartHealthMonitoring runs periodic health checks
|
||||
func (r *RedisMessageBus) StartHealthMonitoring(ctx context.Context, interval time.Duration, onStatusChange func(bool)) {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
previouslyHealthy := true
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
r.logger.Info("stopping health monitoring")
|
||||
return
|
||||
case <-ticker.C:
|
||||
// 检查当前健康状态
|
||||
currentlyHealthy := r.IsHealthy()
|
||||
|
||||
// 如果状态发生变化(健康 -> 亚健康,或 亚健康 -> 恢复)
|
||||
if currentlyHealthy != previouslyHealthy {
|
||||
r.logger.Warn("Redis health status changed",
|
||||
zap.Bool("old_status", previouslyHealthy),
|
||||
zap.Bool("new_status", currentlyHealthy),
|
||||
)
|
||||
|
||||
// 触发外部回调逻辑
|
||||
if onStatusChange != nil {
|
||||
onStatusChange(currentlyHealthy)
|
||||
}
|
||||
|
||||
// 更新历史状态
|
||||
previouslyHealthy = currentlyHealthy
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *RedisMessageBus) Close() error {
|
||||
r.subMu.Lock()
|
||||
defer r.subMu.Unlock()
|
||||
|
||||
r.logger.Info("gracefully shutting down message bus", zap.Int("active_subs", len(r.subscriptions)))
|
||||
|
||||
// 1. 关闭所有正在运行的订阅
|
||||
for roomID, sub := range r.subscriptions {
|
||||
// 停止对应的 forwardMessages 协程
|
||||
sub.cancel()
|
||||
|
||||
// 关闭物理连接
|
||||
if err := sub.pubsub.Close(); err != nil {
|
||||
r.logger.Error("failed to close pubsub connection",
|
||||
zap.String("roomID", roomID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 清空 Map,释放引用以便 GC 回收
|
||||
r.subscriptions = make(map[string]*subscription)
|
||||
|
||||
// 3. 关闭主 Redis 客户端连接池
|
||||
if err := r.client.Close(); err != nil {
|
||||
r.logger.Error("failed to close redis client", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
|
||||
r.logger.Info("Redis message bus closed successfully")
|
||||
return nil
|
||||
}
|
||||
// ClearAllAwareness 彻底删除该房间的感知数据 Hash
|
||||
func (r *RedisMessageBus) ClearAllAwareness(ctx context.Context, roomID string) error {
|
||||
key := fmt.Sprintf("room:%s:awareness", roomID)
|
||||
// 直接使用 Del 命令删除整个 Key
|
||||
return r.client.Del(ctx, key).Err()
|
||||
}
|
||||
Reference in New Issue
Block a user