Files

301 lines
7.3 KiB
Go

package ws
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"filefast/backend/internal/model"
"filefast/backend/internal/service"
"filefast/backend/internal/storage"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
// Timing and size limits applied to every device WebSocket connection.
const (
	// writeWait is the deadline given to each individual frame write.
	writeWait = 10 * time.Second
	// pongWait is the read deadline; it is pushed forward each time a pong
	// arrives from the peer.
	pongWait = 60 * time.Second
	// pingPeriod is 90% of pongWait so a ping (and its pong) can complete
	// before the read deadline lapses.
	pingPeriod = pongWait * 9 / 10
	// maxMessageSize caps a single inbound frame at 8 MiB.
	maxMessageSize = 8 << 20
)
// Hub tracks the WebSocket clients connected to this process and routes
// signalling messages between them, optionally sharing traffic with other
// instances through a realtime backplane.
type Hub struct {
	logger        *slog.Logger
	deviceService *service.DeviceService

	// backplane may be nil, which disables cross-instance delivery.
	// instanceID is stamped on published envelopes so this process can
	// recognise and skip its own echoes.
	backplane  realtimeBackplane
	instanceID string

	// clients maps a device ID to its live connection. In this file it is
	// only touched by Run and the helpers Run calls.
	clients map[string]*Client

	// Event channels consumed by Run.
	register, unregister chan *Client
	relay                chan relayMessage
}
// Client is one device's WebSocket connection. readPump consumes inbound
// frames; writePump drains send and keeps the connection alive with pings.
type Client struct {
	hub      *Hub
	deviceID string
	conn     *websocket.Conn

	// send is the outbound queue. The hub closes it when the client is
	// unregistered, which makes writePump emit a close frame and exit.
	send chan []byte
}
// relayMessage carries an encoded signal envelope from a client's readPump to
// the hub for delivery.
type relayMessage struct {
	targetDeviceID string // device the payload is addressed to
	payload        []byte // JSON-encoded model.SignalEnvelope
}
// inboundMessage is the JSON frame a device sends over its socket. Payload is
// left undecoded so it can be forwarded verbatim.
type inboundMessage struct {
	Type           string          `json:"type"`
	TargetDeviceID string          `json:"target_device_id"`
	Payload        json.RawMessage `json:"payload"`
}
// backplaneEnvelope wraps a payload published to the realtime backplane.
// Source identifies the publishing instance; TargetDeviceID is omitted for
// broadcasts such as presence updates.
type backplaneEnvelope struct {
	Source         string          `json:"source"`
	TargetDeviceID string          `json:"target_device_id,omitempty"`
	Payload        json.RawMessage `json:"payload"`
}
// realtimeBackplane is the pub/sub transport the hub uses to exchange relay
// and presence traffic with other instances (presumably Redis-backed, given
// the storage.Redis* channel names used with it — confirm in storage).
type realtimeBackplane interface {
	// PublishRealtime publishes payload on channel.
	PublishRealtime(ctx context.Context, channel string, payload []byte) error
	// SubscribeRealtime streams messages received on the given channels.
	SubscribeRealtime(ctx context.Context, channels ...string) (<-chan storage.RealtimeMessage, error)
}
// upgrader converts already-authenticated HTTP requests into WebSocket
// connections.
var upgrader = websocket.Upgrader{
	// Any Origin is accepted. Handle validates the deviceId/deviceToken query
	// parameters before upgrading, so a foreign origin would still need a
	// valid token. Staying permissive avoids false rejections behind reverse
	// proxies or custom domains that rewrite Host/Forwarded headers.
	CheckOrigin: func(*http.Request) bool { return true },
}
// NewHub creates a Hub with no connected clients. backplane may be nil to
// run without cross-instance delivery. The register/unregister/relay
// channels are unbuffered, so Run must be running before Handle can make
// progress.
func NewHub(logger *slog.Logger, deviceService *service.DeviceService, backplane realtimeBackplane) *Hub {
	hub := &Hub{
		logger:        logger,
		deviceService: deviceService,
		backplane:     backplane,
	}
	// The creation time in nanoseconds identifies this process on the
	// backplane. NOTE(review): uniqueness across instances relies on clock
	// resolution; consider appending a random suffix.
	hub.instanceID = fmt.Sprint(time.Now().UnixNano())
	hub.clients = map[string]*Client{}
	hub.register = make(chan *Client)
	hub.unregister = make(chan *Client)
	hub.relay = make(chan relayMessage)
	return hub
}
// Run is the hub's event loop and never returns. It serialises all access to
// h.clients: registrations, unregistrations, local relays and messages
// arriving from the backplane are all handled on this goroutine.
//
// If the backplane subscription fails or its channel closes, the hub keeps
// serving local clients only.
func (h *Hub) Run() {
	var backplaneMessages <-chan storage.RealtimeMessage
	if h.backplane != nil {
		messages, err := h.backplane.SubscribeRealtime(context.Background(), storage.RedisRealtimeRelayChannel, storage.RedisRealtimePresenceChannel)
		if err != nil {
			h.logger.Warn("failed to subscribe redis realtime backplane", "error", err)
		} else {
			backplaneMessages = messages
		}
	}
	for {
		select {
		case client := <-h.register:
			if previous, ok := h.clients[client.deviceID]; ok && previous != client {
				// A newer connection for the same device replaces the old
				// one. Closing the old send channel makes its writePump send
				// a close frame and drop the socket, so the stale connection
				// can no longer relay as this device. Its later unregister is
				// a no-op because the map no longer points at it.
				close(previous.send)
			}
			h.clients[client.deviceID] = client
			h.deviceService.SetOnline(client.deviceID, true)
			h.broadcastPresence(client.deviceID, true)
		case client := <-h.unregister:
			// Only the currently registered connection may take the device
			// offline; a superseded connection unregistering is ignored.
			if existing, ok := h.clients[client.deviceID]; ok && existing == client {
				delete(h.clients, client.deviceID)
				close(client.send)
				h.deviceService.SetOnline(client.deviceID, false)
				h.broadcastPresence(client.deviceID, false)
			}
		case message := <-h.relay:
			h.dispatchRelay(message.targetDeviceID, message.payload)
		case message, ok := <-backplaneMessages:
			if !ok {
				// A nil channel is never selected, which disables this case.
				h.logger.Warn("redis realtime backplane subscription closed")
				backplaneMessages = nil
				continue
			}
			h.handleBackplaneMessage(message)
		}
	}
}
// Handle authenticates a device from its deviceId and deviceToken query
// parameters, upgrades the request to a WebSocket and serves it. It blocks
// for the lifetime of the connection, because readPump runs on the calling
// goroutine.
func (h *Hub) Handle(c *gin.Context) {
	deviceID := c.Query("deviceId")
	token := strings.TrimSpace(c.Query("deviceToken"))
	if deviceID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "deviceId is required"})
		return
	}
	if token == "" {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "deviceToken is required"})
		return
	}
	if !h.deviceService.ValidateSession(deviceID, token) {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid device credentials"})
		return
	}
	conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
	if err != nil {
		// Upgrade has already replied with an HTTP error. Writing another
		// response here would trigger a superfluous WriteHeader and append
		// a stray body, so only record the failure.
		h.logger.Warn("failed to upgrade websocket", "device_id", deviceID, "error", err)
		return
	}
	client := &Client{
		hub:      h,
		deviceID: deviceID,
		conn:     conn,
		send:     make(chan []byte, 32),
	}
	h.register <- client
	go client.writePump()
	client.readPump()
}
// broadcastPresence announces that deviceID went online or offline to every
// client on this instance and, through the backplane, to other instances.
func (h *Hub) broadcastPresence(deviceID string, online bool) {
	envelope := model.SignalEnvelope{
		Type:     "presence.update",
		DeviceID: deviceID,
		Payload: map[string]any{
			"online": online,
		},
	}
	data, err := json.Marshal(envelope)
	if err != nil {
		h.logger.Warn("failed to marshal presence update", "error", err)
		return
	}
	h.broadcastPresencePayload(data)
	h.publishBackplane(storage.RedisRealtimePresenceChannel, "", data)
}

// broadcastPresencePayload queues an encoded presence frame for every client
// connected to this instance.
func (h *Hub) broadcastPresencePayload(data []byte) {
	for _, client := range h.clients {
		h.queueToClient(client, data)
	}
}

// dispatchRelay delivers payload to targetDeviceID if that device is
// connected here, and publishes it on the backplane so another instance
// holding the device can deliver it. Handle rejects an empty deviceId, so a
// message without a target can never be delivered and is dropped.
func (h *Hub) dispatchRelay(targetDeviceID string, payload []byte) {
	if targetDeviceID == "" {
		h.logger.Debug("dropping relay message without target device")
		return
	}
	if client, ok := h.clients[targetDeviceID]; ok {
		h.queueToClient(client, payload)
	}
	h.publishBackplane(storage.RedisRealtimeRelayChannel, targetDeviceID, payload)
}

// publishBackplane wraps payload in a backplaneEnvelope stamped with this
// instance's ID and publishes it on channel. Failures are logged, not
// returned. It is a no-op when no backplane is configured.
//
// NOTE(review): this runs on the hub goroutine, so a slow backplane stalls
// all local routing for up to the 2s publish timeout per message.
func (h *Hub) publishBackplane(channel, targetDeviceID string, payload []byte) {
	if h.backplane == nil {
		return
	}
	envelope, err := json.Marshal(backplaneEnvelope{
		Source:         h.instanceID,
		TargetDeviceID: targetDeviceID,
		Payload:        payload,
	})
	if err != nil {
		h.logger.Warn("failed to marshal backplane envelope", "channel", channel, "error", err)
		return
	}
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	if err := h.backplane.PublishRealtime(ctx, channel, envelope); err != nil {
		h.logger.Warn("failed to publish realtime backplane message", "channel", channel, "error", err)
	}
}

// handleBackplaneMessage delivers a message published by another instance to
// the matching local client(s). Messages this instance published itself are
// ignored.
func (h *Hub) handleBackplaneMessage(message storage.RealtimeMessage) {
	var envelope backplaneEnvelope
	if err := json.Unmarshal(message.Payload, &envelope); err != nil {
		h.logger.Warn("failed to decode realtime backplane message", "channel", message.Channel, "error", err)
		return
	}
	if envelope.Source == h.instanceID {
		return
	}
	switch message.Channel {
	case storage.RedisRealtimeRelayChannel:
		if client, ok := h.clients[envelope.TargetDeviceID]; ok {
			h.queueToClient(client, envelope.Payload)
		}
	case storage.RedisRealtimePresenceChannel:
		h.broadcastPresencePayload(envelope.Payload)
	}
}

// queueToClient hands payload to client's writePump without letting one
// stuck client freeze the hub goroutine forever.
//
// The fast path is a non-blocking send. If the buffer is full, the hub waits,
// which preserves backpressure on the sending device, but only for
// 2*writeWait. A live writePump bounds every write by writeWait, so it frees
// a slot within at most two writes (the frame in flight plus possibly one
// ping). If nothing frees up in that time, writePump is presumed dead and the
// connection is closed. The client's readPump then fails and unregisters it
// through Run's normal cleanup path.
func (h *Hub) queueToClient(client *Client, payload []byte) {
	select {
	case client.send <- payload:
		return
	default:
	}
	timer := time.NewTimer(2 * writeWait)
	defer timer.Stop()
	select {
	case client.send <- payload:
	case <-timer.C:
		h.logger.Warn("closing websocket client with stalled send queue", "device_id", client.deviceID)
		// gorilla/websocket allows Close concurrently with the pumps.
		_ = client.conn.Close()
	}
}
// readPump reads JSON frames from the device, stamps each with the sender's
// device ID and hands it to the hub for relay. It returns when a read fails.
// That includes a non-JSON frame, an oversized frame, the pong deadline
// expiring, or the peer closing. On return it unregisters the client and
// closes the connection.
func (c *Client) readPump() {
	defer func() {
		c.hub.unregister <- c
		_ = c.conn.Close()
	}()
	c.conn.SetReadLimit(maxMessageSize)
	_ = c.conn.SetReadDeadline(time.Now().Add(pongWait))
	c.conn.SetPongHandler(func(string) error {
		// Each pong proves the peer is alive; extend the read deadline.
		_ = c.conn.SetReadDeadline(time.Now().Add(pongWait))
		return nil
	})
	for {
		var inbound inboundMessage
		if err := c.conn.ReadJSON(&inbound); err != nil {
			// Normal closes are expected; anything else (decode errors,
			// timeouts, abnormal closes) is worth a trace when debugging.
			if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
				c.hub.logger.Debug("websocket read ended", "device_id", c.deviceID, "error", err)
			}
			return
		}
		data, err := json.Marshal(model.SignalEnvelope{
			Type:           inbound.Type,
			DeviceID:       c.deviceID,
			TargetDeviceID: inbound.TargetDeviceID,
			Payload:        inbound.Payload,
		})
		if err != nil {
			c.hub.logger.Warn("failed to marshal signal envelope", "device_id", c.deviceID, "type", inbound.Type, "error", err)
			continue
		}
		c.hub.relay <- relayMessage{
			targetDeviceID: inbound.TargetDeviceID,
			payload:        data,
		}
	}
}
// writePump is the only writer of frames to the connection (plus the hub's
// concurrent Close). It forwards queued messages, pings the peer every
// pingPeriod, and exits on the first write failure or when the hub closes
// send. In every case it stops the ticker and closes the connection.
func (c *Client) writePump() {
	pinger := time.NewTicker(pingPeriod)
	// Deferred calls run LIFO: the ticker stops first, then the socket closes.
	defer c.conn.Close()
	defer pinger.Stop()

	for {
		select {
		case msg, open := <-c.send:
			_ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if !open {
				// The hub closed the queue: say goodbye and stop.
				_ = c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}
			if c.conn.WriteMessage(websocket.TextMessage, msg) != nil {
				return
			}
		case <-pinger.C:
			_ = c.conn.SetWriteDeadline(time.Now().Add(writeWait))
			if c.conn.WriteMessage(websocket.PingMessage, nil) != nil {
				return
			}
		}
	}
}