Refactor API configuration and improve WebSocket handling in frontend and backend
This commit is contained in:
@@ -8,10 +8,10 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/config"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -20,41 +20,45 @@ import (
|
||||
)
|
||||
|
||||
type AuthHandler struct {
|
||||
store store.Store
|
||||
googleConfig *oauth2.Config
|
||||
githubConfig *oauth2.Config
|
||||
jwtSecret string
|
||||
frontendURL string
|
||||
store store.Store
|
||||
cfg *config.Config
|
||||
googleConfig *oauth2.Config
|
||||
githubConfig *oauth2.Config
|
||||
}
|
||||
|
||||
func NewAuthHandler(store store.Store, jwtSecret, frontendURL string) *AuthHandler {
|
||||
googleConfig := auth.GetGoogleOAuthConfig(
|
||||
os.Getenv("GOOGLE_CLIENT_ID"),
|
||||
os.Getenv("GOOGLE_CLIENT_SECRET"),
|
||||
os.Getenv("GOOGLE_REDIRECT_URL"),
|
||||
)
|
||||
func NewAuthHandler(store store.Store, cfg *config.Config) *AuthHandler {
|
||||
var googleConfig *oauth2.Config
|
||||
if cfg.HasGoogleOAuth() {
|
||||
googleConfig = auth.GetGoogleOAuthConfig(
|
||||
cfg.GoogleClientID,
|
||||
cfg.GoogleClientSecret,
|
||||
cfg.GoogleRedirectURL,
|
||||
)
|
||||
}
|
||||
|
||||
githubConfig := auth.GetGitHubOAuthConfig(
|
||||
os.Getenv("GITHUB_CLIENT_ID"),
|
||||
os.Getenv("GITHUB_CLIENT_SECRET"),
|
||||
os.Getenv("GITHUB_REDIRECT_URL"),
|
||||
)
|
||||
var githubConfig *oauth2.Config
|
||||
if cfg.HasGitHubOAuth() {
|
||||
githubConfig = auth.GetGitHubOAuthConfig(
|
||||
cfg.GitHubClientID,
|
||||
cfg.GitHubClientSecret,
|
||||
cfg.GitHubRedirectURL,
|
||||
)
|
||||
}
|
||||
|
||||
return &AuthHandler{
|
||||
store: store,
|
||||
cfg: cfg,
|
||||
googleConfig: googleConfig,
|
||||
githubConfig: githubConfig,
|
||||
jwtSecret: jwtSecret,
|
||||
frontendURL: frontendURL,
|
||||
}
|
||||
}
|
||||
|
||||
// GoogleLogin redirects to Google OAuth
|
||||
func (h *AuthHandler) GoogleLogin(c *gin.Context) {
|
||||
// Generate random state and set cookie
|
||||
oauthState := generateStateOauthCookie(c.Writer)
|
||||
url := h.googleConfig.AuthCodeURL(oauthState, oauth2.AccessTypeOffline)
|
||||
c.Redirect(http.StatusTemporaryRedirect, url)
|
||||
// Generate random state and set cookie
|
||||
oauthState := h.generateStateOauthCookie(c.Writer)
|
||||
url := h.googleConfig.AuthCodeURL(oauthState, oauth2.AccessTypeOffline)
|
||||
c.Redirect(http.StatusTemporaryRedirect, url)
|
||||
}
|
||||
|
||||
// GoogleCallback handles Google OAuth callback
|
||||
@@ -122,15 +126,15 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Redirect to frontend with token
|
||||
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s", h.frontendURL, jwt)
|
||||
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s", h.cfg.FrontendURL, jwt)
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
}
|
||||
|
||||
// GithubLogin redirects to GitHub OAuth
|
||||
func (h *AuthHandler) GithubLogin(c *gin.Context) {
|
||||
oauthState := generateStateOauthCookie(c.Writer)
|
||||
url := h.githubConfig.AuthCodeURL(oauthState)
|
||||
c.Redirect(http.StatusTemporaryRedirect, url)
|
||||
oauthState := h.generateStateOauthCookie(c.Writer)
|
||||
url := h.githubConfig.AuthCodeURL(oauthState)
|
||||
c.Redirect(http.StatusTemporaryRedirect, url)
|
||||
}
|
||||
|
||||
// GithubCallback handles GitHub OAuth callback
|
||||
@@ -227,7 +231,7 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Redirect to frontend with token
|
||||
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s", h.frontendURL, jwt)
|
||||
redirectURL := fmt.Sprintf("%s/auth/callback?token=%s", h.cfg.FrontendURL, jwt)
|
||||
c.Redirect(http.StatusTemporaryRedirect, redirectURL)
|
||||
}
|
||||
|
||||
@@ -274,7 +278,7 @@ func (h *AuthHandler) createSessionAndJWT(c *gin.Context, user *models.User) (st
|
||||
expiresAt := time.Now().Add(7 * 24 * time.Hour) // 7 days
|
||||
|
||||
// Generate JWT first (we need it for session) - now includes avatar URL
|
||||
jwt, err := auth.GenerateJWT(user.ID, user.Name, user.Email, user.AvatarURL, h.jwtSecret, 7*24*time.Hour)
|
||||
jwt, err := auth.GenerateJWT(user.ID, user.Name, user.Email, user.AvatarURL, h.cfg.JWTSecret, 7*24*time.Hour)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -298,25 +302,25 @@ func (h *AuthHandler) createSessionAndJWT(c *gin.Context, user *models.User) (st
|
||||
|
||||
return jwt, nil
|
||||
}
|
||||
func generateStateOauthCookie(w http.ResponseWriter) string {
|
||||
b := make([]byte, 16)
|
||||
n, err := rand.Read(b)
|
||||
if err != nil || n != 16 {
|
||||
fmt.Printf("Failed to generate random state: %v\n", err)
|
||||
return "" // Critical for CSRF security
|
||||
}
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
func (h *AuthHandler) generateStateOauthCookie(w http.ResponseWriter) string {
|
||||
b := make([]byte, 16)
|
||||
n, err := rand.Read(b)
|
||||
if err != nil || n != 16 {
|
||||
fmt.Printf("Failed to generate random state: %v\n", err)
|
||||
return "" // Critical for CSRF security
|
||||
}
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
|
||||
cookie := http.Cookie{
|
||||
Name: "oauthstate",
|
||||
Value: state,
|
||||
Expires: time.Now().Add(10 * time.Minute),
|
||||
HttpOnly: true, // Prevents JavaScript access (XSS protection)
|
||||
Secure: false, // Must be false for http://localhost (set true in production)
|
||||
SameSite: http.SameSiteLaxMode, // ✅ Allows same-site OAuth redirects
|
||||
Path: "/", // ✅ Ensures cookie is sent to all backend paths
|
||||
}
|
||||
http.SetCookie(w, &cookie)
|
||||
cookie := http.Cookie{
|
||||
Name: "oauthstate",
|
||||
Value: state,
|
||||
Expires: time.Now().Add(10 * time.Minute),
|
||||
HttpOnly: true, // Prevents JavaScript access (XSS protection)
|
||||
Secure: h.cfg.SecureCookie, // true in production, false for localhost
|
||||
SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects
|
||||
Path: "/", // Ensures cookie is sent to all backend paths
|
||||
}
|
||||
http.SetCookie(w, &cookie)
|
||||
|
||||
return state
|
||||
return state
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ package handlers
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os" // Add this
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/config"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/models"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -14,10 +14,11 @@ import (
|
||||
|
||||
type ShareHandler struct {
|
||||
store store.Store
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewShareHandler(store store.Store) *ShareHandler {
|
||||
return &ShareHandler{store: store}
|
||||
func NewShareHandler(store store.Store, cfg *config.Config) *ShareHandler {
|
||||
return &ShareHandler{store: store, cfg: cfg}
|
||||
}
|
||||
|
||||
// CreateShare creates a new document share
|
||||
@@ -193,13 +194,7 @@ func (h *ShareHandler) CreateShareLink(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get frontend URL from env
|
||||
frontendURL := os.Getenv("FRONTEND_URL")
|
||||
if frontendURL == "" {
|
||||
frontendURL = "http://localhost:5173"
|
||||
}
|
||||
|
||||
shareURL := fmt.Sprintf("%s/editor/%s?share=%s", frontendURL, documentID.String(), token)
|
||||
shareURL := fmt.Sprintf("%s/editor/%s?share=%s", h.cfg.FrontendURL, documentID.String(), token)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"url": shareURL,
|
||||
@@ -253,12 +248,7 @@ func (h *ShareHandler) GetShareLink(c *gin.Context) {
|
||||
permission = "edit" // Default fallback
|
||||
}
|
||||
|
||||
frontendURL := os.Getenv("FRONTEND_URL")
|
||||
if frontendURL == "" {
|
||||
frontendURL = "http://localhost:5173"
|
||||
}
|
||||
|
||||
shareURL := fmt.Sprintf("%s/editor/%s?share=%s", frontendURL, documentID.String(), token)
|
||||
shareURL := fmt.Sprintf("%s/editor/%s?share=%s", h.cfg.FrontendURL, documentID.String(), token)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"url": shareURL,
|
||||
|
||||
@@ -25,7 +25,7 @@ func (s *ShareHandlerSuite) SetupTest() {
|
||||
|
||||
// Create handler and router
|
||||
authMiddleware := auth.NewAuthMiddleware(s.store, s.jwtSecret)
|
||||
s.handler = NewShareHandler(s.store)
|
||||
s.handler = NewShareHandler(s.store, s.cfg)
|
||||
s.router = gin.New()
|
||||
|
||||
// Custom auth middleware for tests that sets user_id as pointer
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/config"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
@@ -17,18 +18,26 @@ import (
|
||||
// BaseHandlerSuite provides common setup for all handler tests
|
||||
type BaseHandlerSuite struct {
|
||||
suite.Suite
|
||||
store *store.PostgresStore
|
||||
cleanup func()
|
||||
testData *store.TestData
|
||||
jwtSecret string
|
||||
frontendURL string
|
||||
store *store.PostgresStore
|
||||
cleanup func()
|
||||
testData *store.TestData
|
||||
jwtSecret string
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// SetupSuite runs once before all tests in the suite
|
||||
func (s *BaseHandlerSuite) SetupSuite() {
|
||||
s.store, s.cleanup = store.SetupTestDB(s.T())
|
||||
s.jwtSecret = "test-secret-key-do-not-use-in-production"
|
||||
s.frontendURL = "http://localhost:5173"
|
||||
s.cfg = &config.Config{
|
||||
Port: "8080",
|
||||
Environment: "development",
|
||||
JWTSecret: s.jwtSecret,
|
||||
FrontendURL: "http://localhost:5173",
|
||||
BackendURL: "http://localhost:8080",
|
||||
AllowedOrigins: []string{"http://localhost:5173", "http://localhost:3000"},
|
||||
SecureCookie: false,
|
||||
}
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
|
||||
@@ -3,10 +3,9 @@ package handlers
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/auth"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/config"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/hub"
|
||||
"github.com/M1ngdaXie/realtime-collab/internal/store"
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -14,36 +13,33 @@ import (
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
allowedOrigins := os.Getenv("ALLOWED_ORIGINS")
|
||||
if allowedOrigins == "" {
|
||||
// Default for development
|
||||
return origin == "http://localhost:5173" || origin == "http://localhost:3000"
|
||||
}
|
||||
// Production: validate against ALLOWED_ORIGINS
|
||||
origins := strings.Split(allowedOrigins, ",")
|
||||
for _, allowed := range origins {
|
||||
if strings.TrimSpace(allowed) == origin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
|
||||
type WebSocketHandler struct {
|
||||
hub *hub.Hub
|
||||
store store.Store
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewWebSocketHandler(h *hub.Hub, s store.Store) *WebSocketHandler {
|
||||
func NewWebSocketHandler(h *hub.Hub, s store.Store, cfg *config.Config) *WebSocketHandler {
|
||||
return &WebSocketHandler{
|
||||
hub: h,
|
||||
store: s,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (wsh *WebSocketHandler) getUpgrader() websocket.Upgrader {
|
||||
return websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := r.Header.Get("Origin")
|
||||
for _, allowed := range wsh.cfg.AllowedOrigins {
|
||||
if allowed == origin {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,16 +66,8 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
||||
// Check for JWT token in query parameter
|
||||
jwtToken := c.Query("token")
|
||||
if jwtToken != "" {
|
||||
// Validate JWT signature and expiration - STATELESS, no DB query!
|
||||
jwtSecret := os.Getenv("JWT_SECRET")
|
||||
if jwtSecret == "" {
|
||||
log.Println("JWT_SECRET not configured")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "Server configuration error"})
|
||||
return
|
||||
}
|
||||
|
||||
// Direct JWT validation - fast path (~1ms)
|
||||
claims, err := auth.ValidateJWT(jwtToken, jwtSecret)
|
||||
claims, err := auth.ValidateJWT(jwtToken, wsh.cfg.JWTSecret)
|
||||
if err == nil {
|
||||
// Extract user data from JWT claims
|
||||
uid, parseErr := uuid.Parse(claims.Subject)
|
||||
@@ -151,6 +139,7 @@ func (wsh *WebSocketHandler) HandleWebSocket(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Upgrade connection
|
||||
upgrader := wsh.getUpgrader()
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
log.Printf("Failed to upgrade connection: %v", err)
|
||||
|
||||
Reference in New Issue
Block a user