Refactor API configuration and improve WebSocket handling in frontend and backend
This commit is contained in:
@@ -3,10 +3,9 @@ package handlers
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/config"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/hub"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -14,36 +13,33 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
|
||||
if allowedOrigins == "" {
|
||||
// Default for development
|
||||
return origin == "http://localhost:5173" || origin == "http://localhost:3000"
|
||||
}
|
||||
// Production: validate against ALLOWED_ORIGINS
|
||||
origins := strings.Split(allowedOrigins, ",")
|
||||
for _, allowed := range origins {
|
||||
if strings.TrimSpace(allowed) == origin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
|
||||
type WebSocketHandler struct {
|
||||
hub *hub.Hub
|
||||
store store.Store
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewWebSocketHandler(h *hub.Hub, s store.Store) *WebSocketHandler {
|
||||
func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config) *WebSocketHandler {
|
||||
return &WebSocketHandler{
|
||||
hub: h,
|
||||
store: s,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (wsh *WebSocketHandler) getUpgrader() websocket.Upgrader {
|
||||
return websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
for _, allowed := range wsh.cfg.AllowedOrigins {
|
||||
if allowed == origin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,16 +66,8 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
||||
// Check for JWT token in query parameter
|
||||
jwtToken := c.Query("token")
|
||||
if jwtToken != "" {
|
||||
// Validate JWT signature and expiration - STATELESS, no DB query!
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
log.Println("JWT_SECRET not configured")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Server configuration error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Direct JWT validation - fast path (~1ms)
|
||||
claims, err := auth.ValidateJWT(jwtToken, jwtSecret)
|
||||
claims, err := auth.ValidateJWT(jwtToken, wsh.cfg.JWTSecret)
|
||||
if err == nil {
|
||||
// Extract user data from JWT claims
|
||||
uid, parseErr := uuid.Parse(claims.Subject)
|
||||
@@ -151,6 +139,7 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Upgrade connection
|
||||
upgrader := wsh.getUpgrader()
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to upgrade connection: %v", err)
|
||||
|
||||
Reference in New Issue
Block a user