1 | package middleware |
2 | |
3 | import ( |
4 | "context" |
5 | "crypto/hmac" |
6 | "crypto/sha256" |
7 | "encoding/base64" |
8 | "encoding/hex" |
9 | "log" |
10 | "net/http" |
11 | "net/url" |
12 | "os" |
13 | "strings" |
14 | ) |
15 | |
// contextKey is the unexported key type for values this package stores in a
// request context. Using a dedicated named type keeps these keys distinct from
// keys of any other type (including plain strings) set by other packages.
type contextKey string

// userIDContextKey is the context key under which AuthMiddleware stores the
// authenticated user ID.
const userIDContextKey = contextKey("user_id")
19 | |
20 | func AuthMiddleware(next http.Handler) http.Handler { |
21 | return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
22 | |
23 | cookie, err := r.Cookie("better_auth_session") |
24 | if err != nil { |
25 | log.Printf("Missing cookie: %v", err) |
26 | http.Error(w, "Unauthorized", http.StatusUnauthorized) |
27 | return |
28 | } |
29 | |
30 | |
31 | decodedValue, err := url.QueryUnescape(cookie.Value) |
32 | if err != nil { |
33 | log.Printf("URL decode error: %v", err) |
34 | http.Error(w, "Invalid cookie", http.StatusBadRequest) |
35 | return |
36 | } |
37 | |
38 | |
39 | parts := strings.Split(decodedValue, ".") |
40 | if len(parts) != 2 { |
41 | log.Printf("Invalid cookie format: %s", decodedValue) |
42 | http.Error(w, "Unauthorized", http.StatusUnauthorized) |
43 | return |
44 | } |
45 | |
46 | encodedValue := parts[0] |
47 | signature := parts[1] |
48 | |
49 | |
50 | h := hmac.New(sha256.New, []byte(os.Getenv("BETTER_AUTH_SECRET"))) |
51 | h.Write([]byte(encodedValue)) |
52 | expectedSignature := base64.StdEncoding.EncodeToString(h.Sum(nil)) |
53 | if signature != expectedSignature { |
54 | log.Printf("Signature mismatch: got %s, expected %s", signature, expectedSignature) |
55 | http.Error(w, "Unauthorized", http.StatusUnauthorized) |
56 | return |
57 | } |
58 | |
59 | |
60 | valueBytes, err := base64.StdEncoding.DecodeString(encodedValue) |
61 | if err != nil { |
62 | log.Printf("Base64 decode error: %v, encodedValue: %s", err, encodedValue) |
63 | http.Error(w, "Invalid base64", http.StatusBadRequest) |
64 | return |
65 | } |
66 | |
67 | |
68 | userID := hex.EncodeToString(valueBytes) |
69 | if userID == "" { |
70 | log.Printf("Empty user ID after decoding") |
71 | http.Error(w, "Unauthorized", http.StatusUnauthorized) |
72 | return |
73 | } |
74 | |
75 | log.Printf("Authenticated user ID: %s (raw bytes: %x, encoded: %s)", userID, valueBytes, encodedValue) |
76 | |
77 | |
78 | ctx := context.WithValue(r.Context(), userIDContextKey, userID) |
79 | next.ServeHTTP(w, r.WithContext(ctx)) |
80 | }) |
81 | } |
82 | |
83 | |
84 | func GetUserIDFromContext(ctx context.Context) (string, bool) { |
85 | userID, ok := ctx.Value(userIDContextKey).(string) |
86 | return userID, ok |
87 | } |