package ws

import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"filefast/backend/internal/model"
	"filefast/backend/internal/service"
	"filefast/backend/internal/storage"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/websocket"
)

const (
	// writeWait bounds a single websocket write (message, ping or close).
	writeWait = 10 * time.Second
	// pongWait is how long a connection may stay silent before the read
	// deadline expires; every pong extends it.
	pongWait = 60 * time.Second
	// pingPeriod must be shorter than pongWait so a ping is sent (and a pong
	// can come back) before the read deadline fires.
	pingPeriod = (pongWait * 9) / 10
	// maxMessageSize caps a single inbound websocket message.
	maxMessageSize = 8 * 1024 * 1024
)

// Hub tracks connected device clients and routes signalling messages
// between them, optionally fanning out to other instances through a
// realtime backplane.
//
// clients is owned by the goroutine running Run. All other goroutines talk
// to the hub through the register, unregister and relay channels, which is
// why the map is not guarded by a mutex.
type Hub struct {
	logger        *slog.Logger
	deviceService *service.DeviceService
	backplane     realtimeBackplane
	// instanceID tags messages this process publishes so that it can ignore
	// its own messages when they come back from the backplane.
	instanceID string
	clients    map[string]*Client
	register   chan *Client
	unregister chan *Client
	relay      chan relayMessage
}

// Client is one websocket connection authenticated as deviceID.
type Client struct {
	hub      *Hub
	deviceID string
	conn     *websocket.Conn
	// send is the outbound queue drained by writePump. Only the hub sends
	// on it or closes it.
	send chan []byte
}

// relayMessage is a serialized envelope addressed to a single device.
type relayMessage struct {
	targetDeviceID string
	payload        []byte
}

// inboundMessage is the JSON shape a client sends over the socket.
type inboundMessage struct {
	Type           string          `json:"type"`
	TargetDeviceID string          `json:"target_device_id"`
	Payload        json.RawMessage `json:"payload"`
}

// backplaneEnvelope wraps a payload published to the cross-instance
// backplane.
type backplaneEnvelope struct {
	Source         string          `json:"source"`
	TargetDeviceID string          `json:"target_device_id,omitempty"`
	Payload        json.RawMessage `json:"payload"`
}

// realtimeBackplane is the pub/sub transport the hub needs to exchange
// messages with other instances.
type realtimeBackplane interface {
	PublishRealtime(context.Context, string, []byte) error
	SubscribeRealtime(context.Context, ...string) (<-chan storage.RealtimeMessage, error)
}

// upgrader accepts requests with no Origin header (non-browser clients).
// Otherwise the Origin host must match the request host or one of the
// forwarded hosts.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		origin := strings.TrimSpace(r.Header.Get("Origin"))
		if origin == "" {
			return true
		}
		parsed, err := url.Parse(origin)
		if err != nil {
			return false
		}
		return originHostMatchesRequest(parsed.Host, r)
	},
}

// NewHub builds a Hub. backplane may be nil, in which case messages are
// delivered to locally connected clients only. Run must be started in its
// own goroutine before Handle is used.
func NewHub(logger *slog.Logger, deviceService *service.DeviceService, backplane realtimeBackplane) *Hub {
	return &Hub{
		logger:        logger,
		deviceService: deviceService,
		backplane:     backplane,
		instanceID:    strconv.FormatInt(time.Now().UnixNano(), 10),
		clients:       make(map[string]*Client),
		register:      make(chan *Client),
		unregister:    make(chan *Client),
		relay:         make(chan relayMessage),
	}
}

// Run is the hub's event loop. It never returns. It serializes
// registration, unregistration, local relays and backplane traffic onto
// one goroutine.
func (h *Hub) Run() {
	var backplaneMessages <-chan storage.RealtimeMessage
	if h.backplane != nil {
		messages, err := h.backplane.SubscribeRealtime(context.Background(), storage.RedisRealtimeRelayChannel, storage.RedisRealtimePresenceChannel)
		if err != nil {
			h.logger.Warn("failed to subscribe redis realtime backplane", "error", err)
		} else {
			backplaneMessages = messages
		}
	}

	for {
		select {
		case client := <-h.register:
			// A reconnecting device replaces any previous entry. The previous
			// connection is not closed here; its eventual unregister is
			// ignored by removeClient because it is no longer the map entry.
			h.clients[client.deviceID] = client
			h.deviceService.SetOnline(client.deviceID, true)
			h.broadcastPresence(client.deviceID, true)
		case client := <-h.unregister:
			h.removeClient(client)
		case message := <-h.relay:
			h.dispatchRelay(message.targetDeviceID, message.payload)
		case message, ok := <-backplaneMessages:
			if !ok {
				// Subscription ended; a nil channel is never selected again.
				backplaneMessages = nil
				continue
			}
			h.handleBackplaneMessage(message)
		}
	}
}

// Handle authenticates the deviceId/deviceToken query parameters, upgrades
// the request to a websocket and serves it until the connection ends.
func (h *Hub) Handle(c *gin.Context) {
	deviceID := c.Query("deviceId")
	token := strings.TrimSpace(c.Query("deviceToken"))
	if deviceID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "deviceId is required"})
		return
	}
	if token == "" {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "deviceToken is required"})
		return
	}
	if !h.deviceService.ValidateSession(deviceID, token) {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid device credentials"})
		return
	}

	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already replied with an HTTP error (or hijacked the
		// connection), so writing another response here would be superfluous.
		h.logger.Warn("failed to upgrade websocket", "device_id", deviceID, "error", err)
		return
	}

	client := &Client{
		hub:      h,
		deviceID: deviceID,
		conn:     conn,
		send:     make(chan []byte, 32),
	}
	h.register <- client
	go client.writePump()
	client.readPump()
}

// removeClient drops client from the hub if it is still the registered
// connection for its device. It closes the client's send queue, which makes
// writePump send a close frame and exit, then marks the device offline and
// announces it. It is a no-op for a client that was already removed or
// replaced, so calling it twice is safe.
func (h *Hub) removeClient(client *Client) {
	existing, ok := h.clients[client.deviceID]
	if !ok || existing != client {
		return
	}
	delete(h.clients, client.deviceID)
	close(client.send)
	h.deviceService.SetOnline(client.deviceID, false)
	h.broadcastPresence(client.deviceID, false)
}

// broadcastPresence announces a device's online state to local clients and,
// when configured, to other instances through the backplane.
func (h *Hub) broadcastPresence(deviceID string, online bool) {
	envelope := model.SignalEnvelope{
		Type:     "presence.update",
		DeviceID: deviceID,
		Payload: map[string]any{
			"online": online,
		},
	}
	data, err := json.Marshal(envelope)
	if err != nil {
		h.logger.Warn("failed to marshal presence update", "error", err)
		return
	}
	h.broadcastPresencePayload(data)
	h.publishBackplane(storage.RedisRealtimePresenceChannel, "", data)
}

// broadcastPresencePayload queues data for every local client. Clients
// whose queue is full are removed after the loop rather than letting one
// stuck connection block the hub.
func (h *Hub) broadcastPresencePayload(data []byte) {
	var stuck []*Client
	for _, client := range h.clients {
		if !client.trySend(data) {
			stuck = append(stuck, client)
		}
	}
	for _, client := range stuck {
		h.logger.Warn("dropping unresponsive websocket client", "device_id", client.deviceID)
		h.removeClient(client)
	}
}

// deliverLocal queues payload for targetDeviceID if that device is
// connected to this instance. A client with a full queue is removed.
func (h *Hub) deliverLocal(targetDeviceID string, payload []byte) {
	client, ok := h.clients[targetDeviceID]
	if !ok {
		return
	}
	if !client.trySend(payload) {
		h.logger.Warn("dropping unresponsive websocket client", "device_id", client.deviceID)
		h.removeClient(client)
	}
}

// dispatchRelay delivers a client-originated payload to the target device
// if it is local, and always publishes it to the backplane for other
// instances.
func (h *Hub) dispatchRelay(targetDeviceID string, payload []byte) {
	h.deliverLocal(targetDeviceID, payload)
	h.publishBackplane(storage.RedisRealtimeRelayChannel, targetDeviceID, payload)
}

// publishBackplane wraps payload in a backplaneEnvelope tagged with this
// instance's ID and publishes it on channel. Failures are logged, not
// returned. It runs on the hub goroutine and can block it for up to the
// 2s publish timeout.
func (h *Hub) publishBackplane(channel, targetDeviceID string, payload []byte) {
	if h.backplane == nil {
		return
	}
	envelope, err := json.Marshal(backplaneEnvelope{
		Source:         h.instanceID,
		TargetDeviceID: targetDeviceID,
		Payload:        payload,
	})
	if err != nil {
		h.logger.Warn("failed to marshal backplane envelope", "channel", channel, "error", err)
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	if err := h.backplane.PublishRealtime(ctx, channel, envelope); err != nil {
		h.logger.Warn("failed to publish realtime backplane message", "channel", channel, "error", err)
	}
}

// handleBackplaneMessage applies a message published by another instance
// to local clients. Messages this instance published itself are ignored.
func (h *Hub) handleBackplaneMessage(message storage.RealtimeMessage) {
	var envelope backplaneEnvelope
	if err := json.Unmarshal(message.Payload, &envelope); err != nil {
		h.logger.Warn("failed to decode realtime backplane message", "channel", message.Channel, "error", err)
		return
	}
	if envelope.Source == h.instanceID {
		return
	}
	switch message.Channel {
	case storage.RedisRealtimeRelayChannel:
		h.deliverLocal(envelope.TargetDeviceID, envelope.Payload)
	case storage.RedisRealtimePresenceChannel:
		h.broadcastPresencePayload(envelope.Payload)
	}
}

// trySend queues payload without blocking. It reports false when the
// client's queue is full. It must only be called from the hub goroutine on
// a client that is still registered, so the channel is known to be open.
func (c *Client) trySend(payload []byte) bool {
	select {
	case c.send <- payload:
		return true
	default:
		return false
	}
}

// readPump reads JSON messages from the connection, stamps them with the
// sender's device ID and hands them to the hub for relay. It returns on
// any read error (including a malformed message or a missed pong
// deadline), then unregisters the client and closes the connection.
func (c *Client) readPump() {
	defer func() {
		c.hub.unregister <- c
		c.conn.Close()
	}()
	c.conn.SetReadLimit(maxMessageSize)
	c.conn.SetReadDeadline(time.Now().Add(pongWait))
	c.conn.SetPongHandler(func(string) error {
		c.conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})
	for {
		var inbound inboundMessage
		if err := c.conn.ReadJSON(&inbound); err != nil {
			return
		}
		// DeviceID comes from the authenticated connection, never from the
		// client payload, so a client cannot impersonate another device.
		envelope := model.SignalEnvelope{
			Type:           inbound.Type,
			DeviceID:       c.deviceID,
			TargetDeviceID: inbound.TargetDeviceID,
			Payload:        inbound.Payload,
		}
		data, err := json.Marshal(envelope)
		if err != nil {
			continue
		}
		c.hub.relay <- relayMessage{
			targetDeviceID: inbound.TargetDeviceID,
			payload:        data,
		}
	}
}

// writePump drains the send queue onto the connection and sends periodic
// pings. When the hub closes the queue it sends a close frame and exits. It
// also exits on any write error, closing the connection, which in turn
// makes readPump fail and unregister.
func (c *Client) writePump() {
	ticker := time.NewTicker(pingPeriod)
	defer func() {
		ticker.Stop()
		c.conn.Close()
	}()
	for {
		select {
		case message, ok := <-c.send:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !ok {
				c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
				return
			}
		case <-ticker.C:
			c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}

// originHostMatchesRequest reports whether originHost names the same host
// as the request's Host header or any entry of X-Forwarded-Host. Ports are
// ignored, and loopback addresses compare equal to "localhost".
func originHostMatchesRequest(originHost string, r *http.Request) bool {
	originName, err := normalizeHost(originHost)
	if err != nil {
		return false
	}
	candidates := []string{r.Host}
	// Proxy chains may append to X-Forwarded-Host as a comma-separated list.
	candidates = append(candidates, strings.Split(r.Header.Get("X-Forwarded-Host"), ",")...)
	for _, host := range candidates {
		host = strings.TrimSpace(host)
		if host == "" {
			continue
		}
		requestName, err := normalizeHost(host)
		if err != nil {
			continue
		}
		if requestName == originName {
			return true
		}
	}
	return false
}

// normalizeHost reduces a host, host:port or URL string to a lowercase
// hostname without port or IPv6 brackets. It maps 127.0.0.1 and ::1 to
// "localhost". It returns an error for empty input or an unparsable URL.
func normalizeHost(host string) (string, error) {
	host = strings.TrimSpace(host)
	if host == "" {
		return "", fmt.Errorf("empty host")
	}
	if strings.Contains(host, "://") {
		parsed, err := url.Parse(host)
		if err != nil {
			return "", err
		}
		host = parsed.Host
	}
	name, _, err := net.SplitHostPort(host)
	if err == nil {
		host = name
	} else if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
	}
	host = strings.ToLower(strings.TrimSpace(host))
	switch host {
	case "127.0.0.1", "::1":
		return "localhost", nil
	default:
		return host, nil
	}
}