Auth middleware

public
kalebo2023 Apr 17, 2025 Never 24
Clone
Go Auth.go 87 lines (74 loc) | 2.61 KB
1
package middleware
2
3
import (
4
"context"
5
"crypto/hmac"
6
"crypto/sha256"
7
"encoding/base64"
8
"encoding/hex"
9
"log"
10
"net/http"
11
"net/url"
12
"os"
13
"strings"
14
)
15
16
// contextKey is a private type for this package's context keys, so they
// cannot collide with keys defined by other packages.
type contextKey string

// userIDContextKey is the context key under which AuthMiddleware stores the
// authenticated user ID as a string.
const userIDContextKey contextKey = "user_id"

// sessionRejection explains why a session cookie was refused. status and body
// are sent to the client; detail goes to the server log only and must never
// contain the cookie value, its signature, or the signing secret.
type sessionRejection struct {
	status int
	body   string
	detail string
}

// AuthMiddleware authenticates requests using the "better_auth_session"
// cookie. On success it stores the derived user ID in the request context
// under userIDContextKey and calls next. On failure it writes an error
// response and does not call next.
//
// The cookie value must be URL-encoded "<value>.<signature>", where
// <signature> is the standard-base64 HMAC-SHA256 of <value> keyed with the
// BETTER_AUTH_SECRET environment variable, and <value> is standard base64.
// The user ID is the hex encoding of the base64-decoded <value>.
//
// NOTE(review): that user ID is a direct re-encoding of the signed session
// value, not an identifier looked up anywhere. Confirm this is what
// downstream handlers expect.
//
// Error responses:
//   - 401 "Unauthorized": cookie missing, not exactly two '.'-separated
//     parts, signature mismatch, or empty user ID;
//   - 400 "Invalid cookie": the cookie value is not valid percent-encoding;
//   - 400 "Invalid base64": <value> is not valid standard base64;
//   - 500 "Internal Server Error": BETTER_AUTH_SECRET is unset or empty.
//
// Nothing derived from the cookie is logged, because it would let anyone
// with log access replay or forge sessions.
func AuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		cookie, err := r.Cookie("better_auth_session")
		if err != nil {
			log.Printf("Missing cookie: %v", err)
			http.Error(w, "Unauthorized", http.StatusUnauthorized)
			return
		}

		// The secret is read on every request, as before, rather than being
		// cached in package state.
		userID, rej := verifySessionCookie(cookie.Value, []byte(os.Getenv("BETTER_AUTH_SECRET")))
		if rej != nil {
			log.Print(rej.detail)
			http.Error(w, rej.body, rej.status)
			return
		}

		ctx := context.WithValue(r.Context(), userIDContextKey, userID)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// verifySessionCookie validates a raw (still URL-encoded) session cookie value
// of the form "<value>.<signature>" against secret. It returns the user ID
// derived from it: the hex encoding of the standard-base64-decoded <value>.
// On failure it returns "" and a non-nil *sessionRejection.
func verifySessionCookie(raw string, secret []byte) (string, *sessionRejection) {
	// PathUnescape, not QueryUnescape: QueryUnescape turns a literal '+' (a
	// legal base64 character) into a space and breaks verification. For
	// properly percent-encoded input the two functions agree.
	decoded, err := url.PathUnescape(raw)
	if err != nil {
		return "", &sessionRejection{http.StatusBadRequest, "Invalid cookie", "URL decode error: " + err.Error()}
	}

	parts := strings.Split(decoded, ".")
	if len(parts) != 2 {
		return "", &sessionRejection{http.StatusUnauthorized, "Unauthorized", "Invalid cookie format: expected <value>.<signature>"}
	}
	value, signature := parts[0], parts[1]

	if len(secret) == 0 {
		// Fail closed: with an empty key anyone can compute a valid signature.
		return "", &sessionRejection{http.StatusInternalServerError, "Internal Server Error", "BETTER_AUTH_SECRET is not set; refusing to verify session cookie"}
	}

	mac := hmac.New(sha256.New, secret)
	mac.Write([]byte(value)) // hash.Hash.Write never returns an error
	expected := base64.StdEncoding.EncodeToString(mac.Sum(nil))
	// hmac.Equal compares in constant time, so response timing does not reveal
	// how much of a forged signature was correct.
	if !hmac.Equal([]byte(signature), []byte(expected)) {
		return "", &sessionRejection{http.StatusUnauthorized, "Unauthorized", "Signature mismatch"}
	}

	valueBytes, err := base64.StdEncoding.DecodeString(value)
	if err != nil {
		return "", &sessionRejection{http.StatusBadRequest, "Invalid base64", "Base64 decode error: " + err.Error()}
	}

	userID := hex.EncodeToString(valueBytes)
	if userID == "" {
		return "", &sessionRejection{http.StatusUnauthorized, "Unauthorized", "Empty user ID after decoding"}
	}
	return userID, nil
}
82
83
// GetUserIDFromContext retrieves the authenticated user ID from context
84
func GetUserIDFromContext(ctx context.Context) (string, bool) {
85
userID, ok := ctx.Value(userIDContextKey).(string)
86
return userID, ok
87
}