301 lines
7.3 KiB
Go
301 lines
7.3 KiB
Go
package ws
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log/slog"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"filefast/backend/internal/model"
|
|
"filefast/backend/internal/service"
|
|
"filefast/backend/internal/storage"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Connection tuning shared by readPump and writePump.
const (
	// maxMessageSize is the read limit applied to each inbound message (8 MiB).
	maxMessageSize = 8 << 20
	// writeWait bounds every individual write to the peer.
	writeWait = 10 * time.Second
	// pongWait is the read deadline; readPump renews it each time a pong arrives.
	pongWait = 60 * time.Second
	// pingPeriod is 90% of pongWait so a ping goes out before the deadline lapses.
	pingPeriod = pongWait * 9 / 10
)
|
|
|
|
type Hub struct {
|
|
logger *slog.Logger
|
|
deviceService *service.DeviceService
|
|
backplane realtimeBackplane
|
|
instanceID string
|
|
clients map[string]*Client
|
|
register chan *Client
|
|
unregister chan *Client
|
|
relay chan relayMessage
|
|
}
|
|
|
|
type Client struct {
|
|
hub *Hub
|
|
deviceID string
|
|
conn *websocket.Conn
|
|
send chan []byte
|
|
}
|
|
|
|
// relayMessage carries an already-encoded signal from a readPump to the hub,
// along with the device it is addressed to.
type relayMessage struct {
	targetDeviceID string // recipient device ID
	payload        []byte // JSON-encoded model.SignalEnvelope
}
|
|
|
|
// inboundMessage is the JSON shape a connected device sends over its socket.
// Payload is kept raw and forwarded as-is.
type inboundMessage struct {
	Type           string          `json:"type"`             // signal kind chosen by the sender
	TargetDeviceID string          `json:"target_device_id"` // recipient device
	Payload        json.RawMessage `json:"payload"`          // opaque to the hub
}
|
|
|
|
// backplaneEnvelope wraps a payload published to the realtime backplane.
// Field order matters: it fixes the order of keys in the encoded JSON.
type backplaneEnvelope struct {
	// Source is the publishing hub's instanceID.
	Source string `json:"source"`
	// TargetDeviceID is set for relay messages and omitted for presence.
	TargetDeviceID string `json:"target_device_id,omitempty"`
	// Payload is the already-encoded message to deliver to clients.
	Payload json.RawMessage `json:"payload"`
}
|
|
|
|
type realtimeBackplane interface {
|
|
PublishRealtime(context.Context, string, []byte) error
|
|
SubscribeRealtime(context.Context, ...string) (<-chan storage.RealtimeMessage, error)
|
|
}
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
// Device sessions are authenticated before upgrade. Keeping origin
|
|
// permissive avoids false negatives behind reverse proxies or custom
|
|
// domains where Host/Forwarded headers are rewritten.
|
|
return true
|
|
},
|
|
}
|
|
|
|
func NewHub(logger *slog.Logger, deviceService *service.DeviceService, backplane realtimeBackplane) *Hub {
|
|
return &Hub{
|
|
logger: logger,
|
|
deviceService: deviceService,
|
|
backplane: backplane,
|
|
instanceID: fmt.Sprintf("%d", time.Now().UnixNano()),
|
|
clients: make(map[string]*Client),
|
|
register: make(chan *Client),
|
|
unregister: make(chan *Client),
|
|
relay: make(chan relayMessage),
|
|
}
|
|
}
|
|
|
|
func (h *Hub) Run() {
|
|
var backplaneMessages <-chan storage.RealtimeMessage
|
|
if h.backplane != nil {
|
|
messages, err := h.backplane.SubscribeRealtime(context.Background(), storage.RedisRealtimeRelayChannel, storage.RedisRealtimePresenceChannel)
|
|
if err != nil {
|
|
h.logger.Warn("failed to subscribe redis realtime backplane", "error", err)
|
|
} else {
|
|
backplaneMessages = messages
|
|
}
|
|
}
|
|
|
|
for {
|
|
select {
|
|
case client := <-h.register:
|
|
h.clients[client.deviceID] = client
|
|
h.deviceService.SetOnline(client.deviceID, true)
|
|
h.broadcastPresence(client.deviceID, true)
|
|
case client := <-h.unregister:
|
|
if existing, ok := h.clients[client.deviceID]; ok && existing == client {
|
|
delete(h.clients, client.deviceID)
|
|
close(client.send)
|
|
h.deviceService.SetOnline(client.deviceID, false)
|
|
h.broadcastPresence(client.deviceID, false)
|
|
}
|
|
case message := <-h.relay:
|
|
h.dispatchRelay(message.targetDeviceID, message.payload)
|
|
case message, ok := <-backplaneMessages:
|
|
if !ok {
|
|
backplaneMessages = nil
|
|
continue
|
|
}
|
|
h.handleBackplaneMessage(message)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Hub) Handle(c *gin.Context) {
|
|
deviceID := c.Query("deviceId")
|
|
token := strings.TrimSpace(c.Query("deviceToken"))
|
|
if deviceID == "" {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "deviceId is required"})
|
|
return
|
|
}
|
|
if token == "" {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "deviceToken is required"})
|
|
return
|
|
}
|
|
if !h.deviceService.ValidateSession(deviceID, token) {
|
|
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid device credentials"})
|
|
return
|
|
}
|
|
|
|
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
|
if err != nil {
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to upgrade websocket"})
|
|
return
|
|
}
|
|
|
|
client := &Client{
|
|
hub: h,
|
|
deviceID: deviceID,
|
|
conn: conn,
|
|
send: make(chan []byte, 32),
|
|
}
|
|
|
|
h.register <- client
|
|
go client.writePump()
|
|
client.readPump()
|
|
}
|
|
|
|
func (h *Hub) broadcastPresence(deviceID string, online bool) {
|
|
envelope := model.SignalEnvelope{
|
|
Type: "presence.update",
|
|
DeviceID: deviceID,
|
|
Payload: map[string]any{
|
|
"online": online,
|
|
},
|
|
}
|
|
|
|
data, err := json.Marshal(envelope)
|
|
if err != nil {
|
|
h.logger.Warn("failed to marshal presence update", "error", err)
|
|
return
|
|
}
|
|
|
|
h.broadcastPresencePayload(data)
|
|
h.publishBackplane(storage.RedisRealtimePresenceChannel, "", data)
|
|
}
|
|
|
|
func (h *Hub) broadcastPresencePayload(data []byte) {
|
|
for _, client := range h.clients {
|
|
client.send <- data
|
|
}
|
|
}
|
|
|
|
func (h *Hub) dispatchRelay(targetDeviceID string, payload []byte) {
|
|
if client, ok := h.clients[targetDeviceID]; ok {
|
|
client.send <- payload
|
|
}
|
|
|
|
h.publishBackplane(storage.RedisRealtimeRelayChannel, targetDeviceID, payload)
|
|
}
|
|
|
|
func (h *Hub) publishBackplane(channel, targetDeviceID string, payload []byte) {
|
|
if h.backplane == nil {
|
|
return
|
|
}
|
|
|
|
envelope, err := json.Marshal(backplaneEnvelope{
|
|
Source: h.instanceID,
|
|
TargetDeviceID: targetDeviceID,
|
|
Payload: payload,
|
|
})
|
|
if err != nil {
|
|
h.logger.Warn("failed to marshal backplane envelope", "channel", channel, "error", err)
|
|
return
|
|
}
|
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
defer cancel()
|
|
|
|
if err := h.backplane.PublishRealtime(ctx, channel, envelope); err != nil {
|
|
h.logger.Warn("failed to publish realtime backplane message", "channel", channel, "error", err)
|
|
}
|
|
}
|
|
|
|
func (h *Hub) handleBackplaneMessage(message storage.RealtimeMessage) {
|
|
var envelope backplaneEnvelope
|
|
if err := json.Unmarshal(message.Payload, &envelope); err != nil {
|
|
h.logger.Warn("failed to decode realtime backplane message", "channel", message.Channel, "error", err)
|
|
return
|
|
}
|
|
|
|
if envelope.Source == h.instanceID {
|
|
return
|
|
}
|
|
|
|
switch message.Channel {
|
|
case storage.RedisRealtimeRelayChannel:
|
|
if client, ok := h.clients[envelope.TargetDeviceID]; ok {
|
|
client.send <- envelope.Payload
|
|
}
|
|
case storage.RedisRealtimePresenceChannel:
|
|
h.broadcastPresencePayload(envelope.Payload)
|
|
}
|
|
}
|
|
|
|
func (c *Client) readPump() {
|
|
defer func() {
|
|
c.hub.unregister <- c
|
|
c.conn.Close()
|
|
}()
|
|
|
|
c.conn.SetReadLimit(maxMessageSize)
|
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
c.conn.SetPongHandler(func(string) error {
|
|
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
|
return nil
|
|
})
|
|
|
|
for {
|
|
var inbound inboundMessage
|
|
if err := c.conn.ReadJSON(&inbound); err != nil {
|
|
return
|
|
}
|
|
|
|
envelope := model.SignalEnvelope{
|
|
Type: inbound.Type,
|
|
DeviceID: c.deviceID,
|
|
TargetDeviceID: inbound.TargetDeviceID,
|
|
Payload: inbound.Payload,
|
|
}
|
|
|
|
data, err := json.Marshal(envelope)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
c.hub.relay <- relayMessage{
|
|
targetDeviceID: inbound.TargetDeviceID,
|
|
payload: data,
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Client) writePump() {
|
|
ticker := time.NewTicker(pingPeriod)
|
|
defer func() {
|
|
ticker.Stop()
|
|
c.conn.Close()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.send:
|
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
if !ok {
|
|
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
|
|
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
|
return
|
|
}
|
|
case <-ticker.C:
|
|
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
|
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|