first commit
This commit is contained in:
368
backend/internal/ws/hub.go
Normal file
368
backend/internal/ws/hub.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/model"
|
||||
"filefast/backend/internal/service"
|
||||
"filefast/backend/internal/storage"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
writeWait = 10 * time.Second
|
||||
pongWait = 60 * time.Second
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
maxMessageSize = 8 * 1024 * 1024
|
||||
)
|
||||
|
||||
type Hub struct {
|
||||
logger *slog.Logger
|
||||
deviceService *service.DeviceService
|
||||
backplane realtimeBackplane
|
||||
instanceID string
|
||||
clients map[string]*Client
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
relay chan relayMessage
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
hub *Hub
|
||||
deviceID string
|
||||
conn *websocket.Conn
|
||||
send chan []byte
|
||||
}
|
||||
|
||||
type relayMessage struct {
|
||||
targetDeviceID string
|
||||
payload []byte
|
||||
}
|
||||
|
||||
type inboundMessage struct {
|
||||
Type string `json:"type"`
|
||||
TargetDeviceID string `json:"target_device_id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
|
||||
type backplaneEnvelope struct {
|
||||
Source string `json:"source"`
|
||||
TargetDeviceID string `json:"target_device_id,omitempty"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
|
||||
type realtimeBackplane interface {
|
||||
PublishRealtime(context.Context, string, []byte) error
|
||||
SubscribeRealtime(context.Context, ...string) (<-chan storage.RealtimeMessage, error)
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return originHostMatchesRequest(parsed.Host, r)
|
||||
},
|
||||
}
|
||||
|
||||
func NewHub(logger *slog.Logger, deviceService *service.DeviceService, backplane realtimeBackplane) *Hub {
|
||||
return &Hub{
|
||||
logger: logger,
|
||||
deviceService: deviceService,
|
||||
backplane: backplane,
|
||||
instanceID: fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
clients: make(map[string]*Client),
|
||||
register: make(chan *Client),
|
||||
unregister: make(chan *Client),
|
||||
relay: make(chan relayMessage),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) Run() {
|
||||
var backplaneMessages <-chan storage.RealtimeMessage
|
||||
if h.backplane != nil {
|
||||
messages, err := h.backplane.SubscribeRealtime(context.Background(), storage.RedisRealtimeRelayChannel, storage.RedisRealtimePresenceChannel)
|
||||
if err != nil {
|
||||
h.logger.Warn("failed to subscribe redis realtime backplane", "error", err)
|
||||
} else {
|
||||
backplaneMessages = messages
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case client := <-h.register:
|
||||
h.clients[client.deviceID] = client
|
||||
h.deviceService.SetOnline(client.deviceID, true)
|
||||
h.broadcastPresence(client.deviceID, true)
|
||||
case client := <-h.unregister:
|
||||
if existing, ok := h.clients[client.deviceID]; ok && existing == client {
|
||||
delete(h.clients, client.deviceID)
|
||||
close(client.send)
|
||||
h.deviceService.SetOnline(client.deviceID, false)
|
||||
h.broadcastPresence(client.deviceID, false)
|
||||
}
|
||||
case message := <-h.relay:
|
||||
h.dispatchRelay(message.targetDeviceID, message.payload)
|
||||
case message, ok := <-backplaneMessages:
|
||||
if !ok {
|
||||
backplaneMessages = nil
|
||||
continue
|
||||
}
|
||||
h.handleBackplaneMessage(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) Handle(c *gin.Context) {
|
||||
deviceID := c.Query("deviceId")
|
||||
token := strings.TrimSpace(c.Query("deviceToken"))
|
||||
if deviceID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "deviceId is required"})
|
||||
return
|
||||
}
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "deviceToken is required"})
|
||||
return
|
||||
}
|
||||
if !h.deviceService.ValidateSession(deviceID, token) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid device credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to upgrade websocket"})
|
||||
return
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
hub: h,
|
||||
deviceID: deviceID,
|
||||
conn: conn,
|
||||
send: make(chan []byte, 32),
|
||||
}
|
||||
|
||||
h.register <- client
|
||||
go client.writePump()
|
||||
client.readPump()
|
||||
}
|
||||
|
||||
func (h *Hub) broadcastPresence(deviceID string, online bool) {
|
||||
envelope := model.SignalEnvelope{
|
||||
Type: "presence.update",
|
||||
DeviceID: deviceID,
|
||||
Payload: map[string]any{
|
||||
"online": online,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(envelope)
|
||||
if err != nil {
|
||||
h.logger.Warn("failed to marshal presence update", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
h.broadcastPresencePayload(data)
|
||||
h.publishBackplane(storage.RedisRealtimePresenceChannel, "", data)
|
||||
}
|
||||
|
||||
func (h *Hub) broadcastPresencePayload(data []byte) {
|
||||
for _, client := range h.clients {
|
||||
client.send <- data
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) dispatchRelay(targetDeviceID string, payload []byte) {
|
||||
if client, ok := h.clients[targetDeviceID]; ok {
|
||||
client.send <- payload
|
||||
}
|
||||
|
||||
h.publishBackplane(storage.RedisRealtimeRelayChannel, targetDeviceID, payload)
|
||||
}
|
||||
|
||||
func (h *Hub) publishBackplane(channel, targetDeviceID string, payload []byte) {
|
||||
if h.backplane == nil {
|
||||
return
|
||||
}
|
||||
|
||||
envelope, err := json.Marshal(backplaneEnvelope{
|
||||
Source: h.instanceID,
|
||||
TargetDeviceID: targetDeviceID,
|
||||
Payload: payload,
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Warn("failed to marshal backplane envelope", "channel", channel, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := h.backplane.PublishRealtime(ctx, channel, envelope); err != nil {
|
||||
h.logger.Warn("failed to publish realtime backplane message", "channel", channel, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) handleBackplaneMessage(message storage.RealtimeMessage) {
|
||||
var envelope backplaneEnvelope
|
||||
if err := json.Unmarshal(message.Payload, &envelope); err != nil {
|
||||
h.logger.Warn("failed to decode realtime backplane message", "channel", message.Channel, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if envelope.Source == h.instanceID {
|
||||
return
|
||||
}
|
||||
|
||||
switch message.Channel {
|
||||
case storage.RedisRealtimeRelayChannel:
|
||||
if client, ok := h.clients[envelope.TargetDeviceID]; ok {
|
||||
client.send <- envelope.Payload
|
||||
}
|
||||
case storage.RedisRealtimePresenceChannel:
|
||||
h.broadcastPresencePayload(envelope.Payload)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) readPump() {
|
||||
defer func() {
|
||||
c.hub.unregister <- c
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
c.conn.SetReadLimit(maxMessageSize)
|
||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
c.conn.SetPongHandler(func(string) error {
|
||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
var inbound inboundMessage
|
||||
if err := c.conn.ReadJSON(&inbound); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
envelope := model.SignalEnvelope{
|
||||
Type: inbound.Type,
|
||||
DeviceID: c.deviceID,
|
||||
TargetDeviceID: inbound.TargetDeviceID,
|
||||
Payload: inbound.Payload,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(envelope)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
c.hub.relay <- relayMessage{
|
||||
targetDeviceID: inbound.TargetDeviceID,
|
||||
payload: data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) writePump() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.send:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if !ok {
|
||||
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func originHostMatchesRequest(originHost string, r *http.Request) bool {
|
||||
requestHosts := []string{
|
||||
strings.TrimSpace(r.Host),
|
||||
strings.TrimSpace(r.Header.Get("X-Forwarded-Host")),
|
||||
}
|
||||
|
||||
originName, err := normalizeHost(originHost)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, host := range requestHosts {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
requestName, err := normalizeHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if requestName == originName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeHost(host string) (string, error) {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("empty host")
|
||||
}
|
||||
|
||||
if strings.Contains(host, "://") {
|
||||
parsed, err := url.Parse(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
host = parsed.Host
|
||||
}
|
||||
|
||||
name, _, err := net.SplitHostPort(host)
|
||||
if err == nil {
|
||||
host = name
|
||||
} else if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
host = strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
|
||||
}
|
||||
|
||||
host = strings.ToLower(strings.TrimSpace(host))
|
||||
switch host {
|
||||
case "127.0.0.1", "::1":
|
||||
return "localhost", nil
|
||||
default:
|
||||
return host, nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user