Refactor and improve code consistency across multiple files

- Enhanced SQL queries in `session.go` and `share.go` for clarity and consistency.
- Updated comments for better understanding and maintenance.
- Ensured consistent error handling and return statements across various methods.
This commit is contained in:
M1ngdaXie
2026-02-04 22:01:47 -08:00
parent 0f4cff89a2
commit c84cbafb2c
18 changed files with 629 additions and 631 deletions

View File

@@ -64,10 +64,10 @@ func (h *AuthHandler) GoogleLogin(c *gin.Context) {
// GoogleCallback handles Google OAuth callback
func (h *AuthHandler) GoogleCallback(c *gin.Context) {
oauthState, err := c.Cookie("oauthstate")
if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return
}
if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return
}
log.Println("Google callback state:", c.Query("state"))
// Exchange code for token
token, err := h.googleConfig.Exchange(c.Request.Context(), c.Query("code"))
@@ -94,11 +94,11 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
Name string `json:"name"`
Picture string `json:"picture"`
}
if err := json.Unmarshal(data, &userInfo); err != nil {
log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"})
return
log.Printf("Failed to parse Google response: %v | Data: %s", err, string(data))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid Google response"})
return
}
log.Println("Google user info:", userInfo)
// Upsert user in database
@@ -118,10 +118,10 @@ func (h *AuthHandler) GoogleCallback(c *gin.Context) {
// Create session and JWT
jwt, err := h.createSessionAndJWT(c, user)
if err != nil {
fmt.Printf("❌ DATABASE ERROR: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("CreateSession Error: %v", err),
})
fmt.Printf("❌ DATABASE ERROR: %v\n", err)
c.JSON(http.StatusInternalServerError, gin.H{
"error": fmt.Sprintf("CreateSession Error: %v", err),
})
return
}
@@ -140,10 +140,10 @@ func (h *AuthHandler) GithubLogin(c *gin.Context) {
// GithubCallback handles GitHub OAuth callback
func (h *AuthHandler) GithubCallback(c *gin.Context) {
oauthState, err := c.Cookie("oauthstate")
if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return
}
if err != nil || c.Query("state") != oauthState {
c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid oauth state"})
return
}
log.Println("Github callback state:", c.Query("state"))
code := c.Query("code")
if code == "" {
@@ -160,7 +160,7 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
// Get user info from GitHub
client := h.githubConfig.Client(c.Request.Context(), token)
// Get user profile
resp, err := client.Get("https://api.github.com/user")
if err != nil {
@@ -178,10 +178,10 @@ func (h *AuthHandler) GithubCallback(c *gin.Context) {
AvatarURL string `json:"avatar_url"`
}
if err := json.Unmarshal(data, &userInfo); err != nil {
log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"})
return
}
log.Printf("Failed to parse GitHub response: %v | Data: %s", err, string(data))
c.JSON(http.StatusInternalServerError, gin.H{"error": "Invalid GitHub response"})
return
}
// If email is not public, fetch it separately
if userInfo.Email == "" {
@@ -315,10 +315,10 @@ func (h *AuthHandler) generateStateOauthCookie(w http.ResponseWriter) string {
Name: "oauthstate",
Value: state,
Expires: time.Now().Add(10 * time.Minute),
HttpOnly: true, // Prevents JavaScript access (XSS protection)
Secure: h.cfg.SecureCookie, // true in production, false for localhost
SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects
Path: "/", // Ensures cookie is sent to all backend paths
HttpOnly: true, // Prevents JavaScript access (XSS protection)
Secure: h.cfg.SecureCookie, // true in production, false for localhost
SameSite: http.SameSiteLaxMode, // Allows same-site OAuth redirects
Path: "/", // Ensures cookie is sent to all backend paths
}
http.SetCookie(w, &cookie)

View File

@@ -11,16 +11,15 @@ import (
"github.com/google/uuid"
)
type DocumentHandler struct {
store *store.PostgresStore
}
type DocumentHandler struct {
store *store.PostgresStore
}
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
return &DocumentHandler{store: s}
}
func NewDocumentHandler(s *store.PostgresStore) *DocumentHandler {
return &DocumentHandler{store: s}
}
// CreateDocument creates a new document (requires auth)
// CreateDocument creates a new document (requires auth)
func (h *DocumentHandler) CreateDocument(c *gin.Context) {
userID := auth.GetUserFromContext(c)
if userID == nil {
@@ -44,7 +43,7 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
c.JSON(http.StatusCreated, doc)
}
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
func (h *DocumentHandler) ListDocuments(c *gin.Context) {
userID := auth.GetUserFromContext(c)
fmt.Println("Getting userId, which is : ")
fmt.Println(userID)
@@ -66,8 +65,7 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
})
}
func (h *DocumentHandler) GetDocument(c *gin.Context) {
func (h *DocumentHandler) GetDocument(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
respondBadRequest(c, "Invalid document ID format")
@@ -104,8 +102,9 @@ func (h *DocumentHandler) CreateDocument(c *gin.Context) {
c.JSON(http.StatusOK, doc)
}
// GetDocumentState returns the Yjs state for a document
// GetDocumentState retrieves document state (requires view permission)
// GetDocumentState returns the Yjs state for a document
// GetDocumentState retrieves document state (requires view permission)
func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
@@ -143,7 +142,7 @@ func (h *DocumentHandler) GetDocumentState(c *gin.Context) {
c.Data(http.StatusOK, "application/octet-stream", state)
}
// UpdateDocumentState updates document state (requires edit permission)
// UpdateDocumentState updates document state (requires edit permission)
func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
@@ -195,7 +194,7 @@ func (h *DocumentHandler) UpdateDocumentState(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"message": "State updated successfully"})
}
// DeleteDocument deletes a document (owner only)
// DeleteDocument deletes a document (owner only)
func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
id, err := uuid.Parse(c.Param("id"))
if err != nil {
@@ -237,107 +236,109 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) {
// GetDocumentPermission returns the user's permission level for a document
func (h *DocumentHandler) GetDocumentPermission(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id"))
if err != nil {
respondBadRequest(c, "Invalid document ID format")
return
}
documentID, err := uuid.Parse(c.Param("id"))
if err != nil {
respondBadRequest(c, "Invalid document ID format")
return
}
// 1. 先检查文档是否存在 (Good practice)
_, err = h.store.GetDocument(documentID)
if err != nil {
respondNotFound(c, "document")
return
}
// 1. 先检查文档是否存在 (Good practice)
_, err = h.store.GetDocument(documentID)
if err != nil {
respondNotFound(c, "document")
return
}
userID := auth.GetUserFromContext(c)
shareToken := c.Query("share")
userID := auth.GetUserFromContext(c)
shareToken := c.Query("share")
// 定义两个临时变量,用来存两边的结果
var userPerm string // 存 document_shares 的结果
var tokenPerm string // 存 share_token 的结果
// 定义两个临时变量,用来存两边的结果
var userPerm string // 存 document_shares 的结果
var tokenPerm string // 存 share_token 的结果
// ====================================================
// 步骤 A: 检查个人权限 (Base Permission)
// ====================================================
if userID != nil {
perm, err := h.store.GetUserPermission(c.Request.Context(), documentID, *userID)
if err != nil {
respondInternalError(c, "Failed to get user permission", err)
return
}
userPerm = perm
// ⚠️ 注意:如果 perm 是空,这里不报错!继续往下走!
}
// ====================================================
// 步骤 A: 检查个人权限 (Base Permission)
// ====================================================
if userID != nil {
perm, err := h.store.GetUserPermission(c.Request.Context(), documentID, *userID)
if err != nil {
respondInternalError(c, "Failed to get user permission", err)
return
}
userPerm = perm
// ⚠️ 注意:如果 perm 是空,这里不报错!继续往下走!
}
// ====================================================
// 步骤 B: 检查 Token 权限 (Upgrade Permission)
// ====================================================
if shareToken != "" {
// 先验证 Token 是否有效
valid, err := h.store.ValidateShareToken(c.Request.Context(), documentID, shareToken)
if err != nil {
respondInternalError(c, "Failed to validate token", err)
return
}
// 只有 Token 有效才去取权限
if valid {
p, err := h.store.GetShareLinkPermission(c.Request.Context(), documentID)
if err != nil {
respondInternalError(c, "Failed to get token permission", err)
return
}
tokenPerm = p
// 处理数据库老数据的 fallback
if tokenPerm == "" { tokenPerm = "view" }
}
}
// ====================================================
// 步骤 B: 检查 Token 权限 (Upgrade Permission)
// ====================================================
if shareToken != "" {
// 先验证 Token 是否有效
valid, err := h.store.ValidateShareToken(c.Request.Context(), documentID, shareToken)
if err != nil {
respondInternalError(c, "Failed to validate token", err)
return
}
// ====================================================
// 步骤 C: ⚡️ 权限合并与计算 (The Brain)
// ====================================================
finalPermission := ""
role := "viewer" // 默认角色
// 只有 Token 有效才去取权限
if valid {
p, err := h.store.GetShareLinkPermission(c.Request.Context(), documentID)
if err != nil {
respondInternalError(c, "Failed to get token permission", err)
return
}
tokenPerm = p
// 处理数据库老数据的 fallback
if tokenPerm == "" {
tokenPerm = "view"
}
}
}
// 1. 如果是 Owner无敌直接返回
if userPerm == "owner" {
finalPermission = "edit"
role = "owner"
// 直接返回,不用看 Token 了
c.JSON(http.StatusOK, models.PermissionResponse{
Permission: finalPermission,
Role: role,
})
return
}
// ====================================================
// 步骤 C: ⚡️ 权限合并与计算 (The Brain)
// ====================================================
// 2. 比较 User 和 Token取最大值
// 逻辑:只要任意一边给了 "edit",那就是 "edit"
if userPerm == "edit" || tokenPerm == "edit" {
finalPermission = "edit"
role = "editor"
} else if userPerm == "view" || tokenPerm == "view" {
finalPermission = "view"
role = "viewer"
}
finalPermission := ""
role := "viewer" // 默认角色
// ====================================================
// 步骤 D: 最终判决
// ====================================================
if finalPermission == "" {
// 既没个人权限Token 也不对(或者没 Token
if userID == nil {
respondUnauthorized(c, "Authentication required") // 没登录且没Token
} else {
respondForbidden(c, "You don't have permission") // 登录了但没权限
}
return
}
// 1. 如果是 Owner无敌直接返回
if userPerm == "owner" {
finalPermission = "edit"
role = "owner"
// 直接返回,不用看 Token
c.JSON(http.StatusOK, models.PermissionResponse{
Permission: finalPermission,
Role: role,
})
return
}
c.JSON(http.StatusOK, models.PermissionResponse{
Permission: finalPermission,
Role: role,
})
}
// 2. 比较 User 和 Token取最大值
// 逻辑:只要任意一边给了 "edit",那就是 "edit"
if userPerm == "edit" || tokenPerm == "edit" {
finalPermission = "edit"
role = "editor"
} else if userPerm == "view" || tokenPerm == "view" {
finalPermission = "view"
role = "viewer"
}
// ====================================================
// 步骤 D: 最终判决
// ====================================================
if finalPermission == "" {
// 既没个人权限Token 也不对(或者没 Token
if userID == nil {
respondUnauthorized(c, "Authentication required") // 没登录且没Token
} else {
respondForbidden(c, "You don't have permission") // 登录了但没权限
}
return
}
c.JSON(http.StatusOK, models.PermissionResponse{
Permission: finalPermission,
Role: role,
})
}

View File

@@ -91,5 +91,5 @@ func respondInternalError(c *gin.Context, message string, err error) {
respondWithError(c, http.StatusInternalServerError, "internal_error", message)
}
func respondInvalidID(c *gin.Context, message string) {
respondWithError(c, http.StatusBadRequest, "invalid_id", message)
respondWithError(c, http.StatusBadRequest, "invalid_id", message)
}

View File

@@ -153,6 +153,7 @@ func (h *ShareHandler) DeleteShare(c *gin.Context) {
c.Status(204)
}
// CreateShareLink generates a public share link
func (h *ShareHandler) CreateShareLink(c *gin.Context) {
documentID, err := uuid.Parse(c.Param("id"))
@@ -290,4 +291,4 @@ func (h *ShareHandler) RevokeShareLink(c *gin.Context) {
// c.JSON(http.StatusOK, gin.H{"message": "Share link revoked successfully"})
c.Status(204)
}
}

View File

@@ -70,7 +70,7 @@ func TestShareHandlerSuite(t *testing.T) {
func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() {
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "view",
}
@@ -87,7 +87,7 @@ func (s *ShareHandlerSuite) TestCreateShare_ViewPermission() {
func (s *ShareHandlerSuite) TestCreateShare_EditPermission() {
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "edit",
}
@@ -104,7 +104,7 @@ func (s *ShareHandlerSuite) TestCreateShare_EditPermission() {
func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() {
body := map[string]interface{}{
"user_email": "charlie@test.com",
"user_email": "charlie@test.com",
"permission": "view",
}
@@ -118,7 +118,7 @@ func (s *ShareHandlerSuite) TestCreateShare_NonOwnerDenied() {
func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() {
body := map[string]interface{}{
"user_email": "nonexistent@test.com",
"user_email": "nonexistent@test.com",
"permission": "view",
}
@@ -131,7 +131,7 @@ func (s *ShareHandlerSuite) TestCreateShare_UserNotFound() {
func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() {
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "admin", // Invalid permission
}
@@ -145,7 +145,7 @@ func (s *ShareHandlerSuite) TestCreateShare_InvalidPermission() {
func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() {
// Create initial share with view permission
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "view",
}
@@ -169,7 +169,7 @@ func (s *ShareHandlerSuite) TestCreateShare_UpdatesExisting() {
func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() {
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "view",
}
@@ -182,7 +182,7 @@ func (s *ShareHandlerSuite) TestCreateShare_Unauthorized() {
func (s *ShareHandlerSuite) TestCreateShare_InvalidDocumentID() {
body := map[string]interface{}{
"user_email": "bob@test.com",
"user_email": "bob@test.com",
"permission": "view",
}
@@ -206,8 +206,8 @@ func (s *ShareHandlerSuite) TestListShares_OwnerSeesAll() {
s.assertSuccessResponse(w, http.StatusOK)
var response models.ShareListResponse
s.parseJSONResponse(w, &response)
shares := response.Shares
s.parseJSONResponse(w, &response)
shares := response.Shares
s.GreaterOrEqual(len(shares), 1, "Should have at least one share")
}
@@ -229,8 +229,8 @@ func (s *ShareHandlerSuite) TestListShares_EmptyList() {
s.assertSuccessResponse(w, http.StatusOK)
var response models.ShareListResponse
s.parseJSONResponse(w, &response)
shares := response.Shares
s.parseJSONResponse(w, &response)
shares := response.Shares
s.Equal(0, len(shares), "Should have no shares")
}
@@ -243,8 +243,8 @@ func (s *ShareHandlerSuite) TestListShares_IncludesUserDetails() {
s.assertSuccessResponse(w, http.StatusOK)
var response models.ShareListResponse
s.parseJSONResponse(w, &response)
shares := response.Shares
s.parseJSONResponse(w, &response)
shares := response.Shares
if len(shares) > 0 {
share := shares[0]
@@ -266,7 +266,7 @@ func (s *ShareHandlerSuite) TestListShares_OrderedByCreatedAt() {
users := []string{"bob@test.com", "charlie@test.com"}
for _, email := range users {
body := map[string]interface{}{
"user_email": email,
"user_email": email,
"permission": "view",
}
w, httpReq, err := s.makeAuthRequest("POST", fmt.Sprintf("/api/documents/%s/shares", s.testData.AlicePrivateDoc), body, s.testData.AliceID)

View File

@@ -48,10 +48,10 @@ func (wsh *WebSocketHandler) HandleWebSocketLoadTest(c *gin.Context) {
clientID := uuid.New().String()
client := hub.NewClient(
clientID,
nil, // userID - nil for anonymous
userName, // userName
nil, // userAvatar
"edit", // permission - full access for load testing
nil, // userID - nil for anonymous
userName, // userName
nil, // userAvatar
"edit", // permission - full access for load testing
conn,
wsh.hub,
roomID,