first commit

This commit is contained in:
2026-03-28 15:43:18 +08:00
commit e5611df24e
54 changed files with 11065 additions and 0 deletions

View File

@@ -0,0 +1,252 @@
package config
import (
"bufio"
"log/slog"
"os"
"path/filepath"
"slices"
"strconv"
"strings"
"time"
"filefast/backend/internal/model"
)
type Config struct {
HTTPAddress string
LogLevel slog.Level
RoomTTL time.Duration
Admin AdminConfig
SQLite SQLiteConfig
Redis RedisConfig
MinIO MinIOConfig
Runtime model.RuntimeConfig
}
type AdminConfig struct {
Username string
Password string
}
type MinIOConfig struct {
Endpoint string
AccessKey string
SecretKey string
UseSSL bool
Bucket string
PresignExpiry time.Duration
Retention time.Duration
UsageAlertLevel int
}
type RedisConfig struct {
Addr string
Password string
DB int
}
type SQLiteConfig struct {
Path string
}
func Load() Config {
loadDotEnv()
retentionHours := envInt("MINIO_RETENTION_HOURS", 2)
capacityGB := envInt("MINIO_CAPACITY_GB", 120)
return Config{
HTTPAddress: envString("HTTP_ADDR", ":8080"),
LogLevel: parseLogLevel(envString("LOG_LEVEL", "info")),
RoomTTL: time.Duration(envInt("ROOM_TTL_SECONDS", 300)) * time.Second,
Admin: AdminConfig{
Username: envString("ADMIN_USERNAME", ""),
Password: envString("ADMIN_PASSWORD", ""),
},
SQLite: SQLiteConfig{
Path: envString("SQLITE_PATH", filepath.Join("backend", "data", "filefast.db")),
},
Redis: RedisConfig{
Addr: envString("REDIS_ADDR", "127.0.0.1:6379"),
Password: envString("REDIS_PASSWORD", ""),
DB: envInt("REDIS_DB", 0),
},
MinIO: MinIOConfig{
Endpoint: envString("MINIO_ENDPOINT", ""),
AccessKey: envString("MINIO_ACCESS_KEY", ""),
SecretKey: envString("MINIO_SECRET_KEY", ""),
UseSSL: envBool("MINIO_USE_SSL", false),
Bucket: envString("MINIO_BUCKET", "filefast-fallback"),
PresignExpiry: time.Duration(envInt("MINIO_PRESIGN_MINUTES", 30)) * time.Minute,
Retention: time.Duration(retentionHours) * time.Hour,
UsageAlertLevel: envInt("MINIO_USAGE_ALERT_PERCENT", 85),
},
Runtime: model.RuntimeConfig{
MaxMinIOFallbackSizeBytes: int64(envInt("MAX_MINIO_FALLBACK_GB", 10)) * 1024 * 1024 * 1024,
MinIOCapacityBytes: int64(capacityGB) * 1024 * 1024 * 1024,
MinIORetentionHours: retentionHours,
MinIOUsageAlertPercent: envInt("MINIO_USAGE_ALERT_PERCENT", 85),
P2PConnectTimeoutSec: envInt("P2P_CONNECT_TIMEOUT_SEC", 15),
TURNConnectTimeoutSec: envInt("TURN_CONNECT_TIMEOUT_SEC", 20),
MinIOFallbackEnabled: true,
TURNURLs: envCSV("TURN_URLS"),
TURNUsername: envString("TURN_USERNAME", ""),
TURNPassword: envString("TURN_PASSWORD", ""),
},
}
}
func loadDotEnv() {
for _, candidate := range dotEnvCandidates() {
if loadDotEnvFile(candidate) {
return
}
}
}
func dotEnvCandidates() []string {
cwd, err := os.Getwd()
if err != nil {
return []string{".env", filepath.Join("backend", ".env")}
}
candidates := make([]string, 0, 16)
seen := make(map[string]struct{})
current := cwd
for {
for _, name := range []string{".env", filepath.Join("backend", ".env")} {
candidate := filepath.Clean(filepath.Join(current, name))
if _, ok := seen[candidate]; ok {
continue
}
seen[candidate] = struct{}{}
candidates = append(candidates, candidate)
}
parent := filepath.Dir(current)
if parent == current {
break
}
current = parent
}
// Preserve the most local candidates first, but keep the historical relative
// fallbacks at the end for compatibility.
for _, legacy := range []string{".env", filepath.Join("backend", ".env")} {
if !slices.Contains(candidates, legacy) {
candidates = append(candidates, legacy)
}
}
return candidates
}
func loadDotEnvFile(path string) bool {
file, err := os.Open(path)
if err != nil {
return false
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "\uFEFF"))
if line == "" || strings.HasPrefix(line, "#") {
continue
}
if strings.HasPrefix(line, "export ") {
line = strings.TrimSpace(strings.TrimPrefix(line, "export "))
}
key, value, ok := strings.Cut(line, "=")
if !ok {
continue
}
key = strings.TrimSpace(key)
value = strings.TrimSpace(value)
if key == "" {
continue
}
if _, exists := os.LookupEnv(key); exists {
continue
}
if len(value) >= 2 {
if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) ||
(strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) {
value = value[1 : len(value)-1]
}
}
_ = os.Setenv(key, value)
}
return true
}
func envString(key, fallback string) string {
value := strings.TrimSpace(os.Getenv(key))
if value == "" {
return fallback
}
return value
}
func envInt(key string, fallback int) int {
value := strings.TrimSpace(os.Getenv(key))
if value == "" {
return fallback
}
parsed, err := strconv.Atoi(value)
if err != nil {
return fallback
}
return parsed
}
func envBool(key string, fallback bool) bool {
value := strings.TrimSpace(strings.ToLower(os.Getenv(key)))
if value == "" {
return fallback
}
return value == "1" || value == "true" || value == "yes"
}
func envCSV(key string) []string {
raw := strings.TrimSpace(os.Getenv(key))
if raw == "" {
return nil
}
parts := strings.Split(raw, ",")
values := make([]string, 0, len(parts))
for _, part := range parts {
value := strings.TrimSpace(part)
if value != "" {
values = append(values, value)
}
}
if len(values) == 0 {
return nil
}
return values
}
func parseLogLevel(level string) slog.Level {
switch strings.ToLower(strings.TrimSpace(level)) {
case "debug":
return slog.LevelDebug
case "warn":
return slog.LevelWarn
case "error":
return slog.LevelError
default:
return slog.LevelInfo
}
}

View File

@@ -0,0 +1,593 @@
package handler
import (
"context"
"net/http"
"net/url"
"path/filepath"
"strings"
"time"
"filefast/backend/internal/config"
"filefast/backend/internal/model"
"filefast/backend/internal/service"
"filefast/backend/internal/storage"
"filefast/backend/internal/store"
"filefast/backend/internal/ws"
"github.com/gin-gonic/gin"
"log/slog"
)
type Dependencies struct {
Config config.Config
Logger *slog.Logger
Store *store.MemoryStore
DeviceService *service.DeviceService
RoomService *service.RoomService
TransferService *service.TransferService
AdminService *service.AdminService
MinIOClient *storage.MinIOClient
Hub *ws.Hub
StorageReady bool
RedisReady bool
}
type HTTPHandler struct {
deps Dependencies
}
func NewHTTPHandler(deps Dependencies) *HTTPHandler {
return &HTTPHandler{deps: deps}
}
func (h *HTTPHandler) Router() *gin.Engine {
router := gin.Default()
router.GET("/healthz", h.handleHealth)
router.GET("/ws", h.deps.Hub.Handle)
api := router.Group("/api")
{
api.GET("/runtime/config", h.runtimeConfig)
api.POST("/devices/register", h.registerDevice)
api.POST("/admin/login", h.adminLogin)
}
device := api.Group("/")
device.Use(h.requireDevice())
{
device.POST("/devices/heartbeat", h.deviceHeartbeat)
device.GET("/devices/candidates", h.listCandidates)
device.GET("/devices/:id/pending-downloads", h.pendingFallbackDownloads)
device.POST("/rooms", h.createRoom)
device.GET("/rooms/:code", h.getRoom)
device.POST("/rooms/join", h.joinRoom)
device.POST("/rooms/:code/cancel", h.cancelRoom)
device.POST("/transfers", h.createTransfer)
device.PATCH("/transfers/:id/status", h.updateTransferStatus)
device.POST("/transfers/:id/fallback/presign", h.presignFallback)
device.PUT("/transfers/:id/fallback/upload", h.uploadFallback)
device.GET("/transfers/:id/fallback/download", h.downloadFallback)
}
admin := api.Group("/admin")
admin.Use(h.requireAdmin())
{
admin.GET("/stats", h.adminStats)
admin.GET("/config", h.adminConfig)
admin.PUT("/config", h.updateAdminConfig)
admin.GET("/transfers/recent", h.recentTransfers)
}
return router
}
func (h *HTTPHandler) handleHealth(c *gin.Context) {
status := "ok"
if !h.deps.StorageReady || !h.deps.RedisReady {
status = "degraded"
}
c.JSON(http.StatusOK, gin.H{
"status": status,
"minio_enabled": h.deps.MinIOClient != nil && h.deps.MinIOClient.Enabled(),
"storage_ready": h.deps.StorageReady,
"redis_ready": h.deps.RedisReady,
"turn_enabled": len(h.deps.Store.RuntimeConfig().TURNURLs) > 0,
"room_ttl_sec": int(h.deps.Config.RoomTTL.Seconds()),
"server_time_unix": time.Now().Unix(),
})
}
func (h *HTTPHandler) runtimeConfig(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"data": h.deps.Store.RuntimeConfig()})
}
func (h *HTTPHandler) registerDevice(c *gin.Context) {
var input service.RegisterDeviceInput
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
device, session := h.deps.DeviceService.Register(input, c.Request.UserAgent(), c.GetHeader("X-Device-Token"))
c.JSON(http.StatusOK, gin.H{"data": gin.H{
"id": device.ID,
"name": device.Name,
"type": device.Type,
"user_agent": device.UserAgent,
"network_group_key": device.NetworkGroupKey,
"public_ip_hash": device.PublicIPHash,
"is_online": device.IsOnline,
"last_seen_at": device.LastSeenAt,
"created_at": device.CreatedAt,
"auth_token": session.Token,
"auth_expires_at": session.ExpiresAt,
}})
}
func (h *HTTPHandler) deviceHeartbeat(c *gin.Context) {
var input struct {
DeviceID string `json:"device_id" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if !h.ensureAuthenticatedDevice(c, input.DeviceID) {
return
}
device, ok := h.deps.DeviceService.Heartbeat(input.DeviceID)
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "device not found"})
return
}
c.JSON(http.StatusOK, gin.H{"data": device})
}
func (h *HTTPHandler) listCandidates(c *gin.Context) {
deviceID := c.Query("deviceId")
if deviceID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "deviceId is required"})
return
}
if !h.ensureAuthenticatedDevice(c, deviceID) {
return
}
c.JSON(http.StatusOK, gin.H{"data": h.deps.DeviceService.ListCandidates(deviceID)})
}
func (h *HTTPHandler) createRoom(c *gin.Context) {
var input struct {
CreatorDeviceID string `json:"creator_device_id" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if !h.ensureAuthenticatedDevice(c, input.CreatorDeviceID) {
return
}
room, err := h.deps.RoomService.CreateRoom(input.CreatorDeviceID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"data": room})
}
func (h *HTTPHandler) getRoom(c *gin.Context) {
room, ok := h.deps.Store.GetRoom(c.Param("code"))
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "room not found"})
return
}
deviceID := h.authenticatedDeviceID(c)
if room.CreatorDeviceID != deviceID && room.JoinerDeviceID != deviceID {
c.JSON(http.StatusForbidden, gin.H{"error": "room access denied"})
return
}
c.JSON(http.StatusOK, gin.H{"data": room})
}
func (h *HTTPHandler) joinRoom(c *gin.Context) {
var input struct {
Code string `json:"code" binding:"required"`
JoinerDeviceID string `json:"joiner_device_id" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if !h.ensureAuthenticatedDevice(c, input.JoinerDeviceID) {
return
}
room, err := h.deps.RoomService.JoinRoom(input.Code, input.JoinerDeviceID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"data": room})
}
func (h *HTTPHandler) cancelRoom(c *gin.Context) {
var input struct {
RequesterID string `json:"requester_id" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if !h.ensureAuthenticatedDevice(c, input.RequesterID) {
return
}
room, err := h.deps.RoomService.CancelRoom(c.Param("code"), input.RequesterID)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"data": room})
}
func (h *HTTPHandler) createTransfer(c *gin.Context) {
var input service.CreateTransferInput
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if !h.ensureAuthenticatedDevice(c, input.SenderDeviceID) {
return
}
transfer, err := h.deps.TransferService.Create(input)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"data": transfer})
}
func (h *HTTPHandler) updateTransferStatus(c *gin.Context) {
transfer, ok := h.deps.Store.GetTransfer(c.Param("id"))
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "transfer not found"})
return
}
deviceID := h.authenticatedDeviceID(c)
if transfer.SenderDeviceID != deviceID && transfer.ReceiverDeviceID != deviceID {
c.JSON(http.StatusForbidden, gin.H{"error": "transfer access denied"})
return
}
var input service.UpdateTransferStatusInput
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
transfer, err := h.deps.TransferService.UpdateStatus(c.Param("id"), input)
if err != nil {
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"data": transfer})
}
func (h *HTTPHandler) presignFallback(c *gin.Context) {
if h.deps.MinIOClient == nil || !h.deps.MinIOClient.Enabled() {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "minio fallback is disabled"})
return
}
transfer, ok := h.deps.Store.GetTransfer(c.Param("id"))
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "transfer not found"})
return
}
if transfer.SenderDeviceID != h.authenticatedDeviceID(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "transfer access denied"})
return
}
transfer, object, err := h.deps.TransferService.PrepareFallback(c.Param("id"))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
defer cancel()
if err := h.ensureFallbackBucket(ctx, transfer.ID); err != nil {
h.deps.Logger.Warn("minio ensure bucket failed", "transfer_id", transfer.ID, "error", err)
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
uploadURL, err := h.deps.MinIOClient.PresignUpload(ctx, object.ObjectKey)
if err != nil {
h.deps.Logger.Warn("minio presign upload failed", "transfer_id", transfer.ID, "object_key", object.ObjectKey, "error", err)
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
downloadURL, err := h.deps.MinIOClient.PresignDownload(ctx, object.ObjectKey)
if err != nil {
h.deps.Logger.Warn("minio presign download failed", "transfer_id", transfer.ID, "object_key", object.ObjectKey, "error", err)
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"data": gin.H{
"transfer": transfer,
"upload_url": uploadURL.String(),
"download_url": downloadURL.String(),
"download_path": fallbackDownloadPath(transfer.ID),
"expires_at": object.ExpiresAt,
},
})
}
func (h *HTTPHandler) uploadFallback(c *gin.Context) {
if h.deps.MinIOClient == nil || !h.deps.MinIOClient.Enabled() {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "minio fallback is disabled"})
return
}
transfer, ok := h.deps.Store.GetTransfer(c.Param("id"))
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "transfer not found"})
return
}
if transfer.ObjectKey == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "fallback object is not prepared"})
return
}
if transfer.SenderDeviceID != h.authenticatedDeviceID(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "transfer access denied"})
return
}
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Minute)
defer cancel()
if err := h.ensureFallbackBucket(ctx, transfer.ID); err != nil {
h.deps.Logger.Warn("minio ensure bucket failed", "transfer_id", transfer.ID, "error", err)
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
contentLength := c.Request.ContentLength
if contentLength <= 0 {
contentLength = transfer.SizeBytes
}
contentType := strings.TrimSpace(c.GetHeader("Content-Type"))
if contentType == "" {
contentType = "application/octet-stream"
}
if err := h.deps.MinIOClient.UploadObject(ctx, transfer.ObjectKey, c.Request.Body, contentLength, contentType); err != nil {
h.deps.Logger.Warn("minio upload object failed", "transfer_id", transfer.ID, "object_key", transfer.ObjectKey, "error", err)
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
if object, ok := h.deps.Store.GetFallbackObject(transfer.ID); ok {
object.SizeBytes = contentLength
object.CleanupState = "ready"
h.deps.Store.SaveFallbackObject(object)
}
downloadURL, err := h.deps.MinIOClient.PresignDownload(ctx, transfer.ObjectKey)
if err != nil {
h.deps.Logger.Warn("minio presign download failed after upload", "transfer_id", transfer.ID, "object_key", transfer.ObjectKey, "error", err)
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"data": gin.H{
"download_url": downloadURL.String(),
"download_path": fallbackDownloadPath(transfer.ID),
"object_key": transfer.ObjectKey,
},
})
}
func (h *HTTPHandler) pendingFallbackDownloads(c *gin.Context) {
deviceID := strings.TrimSpace(c.Param("id"))
if deviceID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "device id is required"})
return
}
if !h.ensureAuthenticatedDevice(c, deviceID) {
return
}
c.JSON(http.StatusOK, gin.H{
"data": h.deps.Store.ListPendingFallbackDownloads(deviceID, 20),
})
}
func (h *HTTPHandler) downloadFallback(c *gin.Context) {
if h.deps.MinIOClient == nil || !h.deps.MinIOClient.Enabled() {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "minio fallback is disabled"})
return
}
transfer, ok := h.deps.Store.GetTransfer(c.Param("id"))
if !ok {
c.JSON(http.StatusNotFound, gin.H{"error": "transfer not found"})
return
}
if transfer.ObjectKey == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "fallback object is not prepared"})
return
}
if transfer.ReceiverDeviceID != h.authenticatedDeviceID(c) {
c.JSON(http.StatusForbidden, gin.H{"error": "transfer access denied"})
return
}
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Minute)
defer cancel()
filename := filepath.Base(transfer.Name)
if filename == "." || filename == "" {
filename = "download.bin"
}
downloadURL, err := h.deps.MinIOClient.PresignDownloadWithFilename(ctx, transfer.ObjectKey, filename)
if err != nil {
h.deps.Logger.Warn("minio presign download failed", "transfer_id", transfer.ID, "object_key", transfer.ObjectKey, "error", err)
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
return
}
c.Redirect(http.StatusTemporaryRedirect, downloadURL.String())
}
func (h *HTTPHandler) adminLogin(c *gin.Context) {
var input struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
}
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
session, err := h.deps.AdminService.Login(input.Username, input.Password)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"data": session})
}
func (h *HTTPHandler) adminStats(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"data": gin.H{
"stats": h.deps.Store.SnapshotStats(),
"minio": h.deps.Store.SnapshotMinIOStorage(),
},
})
}
func (h *HTTPHandler) adminConfig(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"data": h.deps.Store.RuntimeConfig()})
}
func (h *HTTPHandler) updateAdminConfig(c *gin.Context) {
var input model.RuntimeConfig
if err := c.ShouldBindJSON(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"data": h.deps.Store.UpdateRuntimeConfig(input)})
}
func (h *HTTPHandler) recentTransfers(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"data": h.deps.Store.ListRecentTransfers(20)})
}
func (h *HTTPHandler) requireAdmin() gin.HandlerFunc {
return func(c *gin.Context) {
header := c.GetHeader("Authorization")
if header == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
return
}
token := strings.TrimSpace(strings.TrimPrefix(header, "Bearer"))
if token == "" || !h.deps.AdminService.ValidateToken(token) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid admin token"})
return
}
c.Next()
}
}
func (h *HTTPHandler) requireDevice() gin.HandlerFunc {
return func(c *gin.Context) {
deviceID := strings.TrimSpace(c.GetHeader("X-Device-ID"))
token := strings.TrimSpace(c.GetHeader("X-Device-Token"))
if deviceID == "" || token == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing device credentials"})
return
}
if !h.deps.DeviceService.ValidateSession(deviceID, token) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid device credentials"})
return
}
c.Set("device_id", deviceID)
c.Next()
}
}
func (h *HTTPHandler) authenticatedDeviceID(c *gin.Context) string {
value, ok := c.Get("device_id")
if !ok {
return ""
}
deviceID, _ := value.(string)
return deviceID
}
func (h *HTTPHandler) ensureAuthenticatedDevice(c *gin.Context, expected string) bool {
if h.authenticatedDeviceID(c) != strings.TrimSpace(expected) {
c.JSON(http.StatusForbidden, gin.H{"error": "device access denied"})
return false
}
return true
}
func fallbackDownloadPath(transferID string) string {
return "/api/transfers/" + url.PathEscape(transferID) + "/fallback/download"
}
func contentDisposition(filename string) string {
escaped := strings.ReplaceAll(filename, `"`, "")
return `attachment; filename="` + escaped + `"`
}
func (h *HTTPHandler) ensureFallbackBucket(ctx context.Context, transferID string) error {
if h.deps.MinIOClient == nil {
return nil
}
err := h.deps.MinIOClient.EnsureBucket(ctx)
if err == nil {
return nil
}
// Some MinIO users can read/write objects in an existing bucket but cannot run
// bucket-existence or bucket-creation checks. In that case we keep going.
if strings.Contains(strings.ToLower(err.Error()), "access denied") {
h.deps.Logger.Info("minio bucket ensure skipped due to limited permissions", "transfer_id", transferID)
return nil
}
return err
}

View File

@@ -0,0 +1,195 @@
package handler
import (
"bytes"
"encoding/json"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"filefast/backend/internal/config"
"filefast/backend/internal/model"
"filefast/backend/internal/service"
"filefast/backend/internal/store"
"filefast/backend/internal/ws"
)
type registeredDevice struct {
ID string `json:"id"`
AuthToken string `json:"auth_token"`
}
type transferRecord struct {
ID string `json:"id"`
}
func TestProtectedRoutesRequireDeviceCredentials(t *testing.T) {
router, _ := newTestRouter()
device := registerDevice(t, router, map[string]any{
"device_id": "alpha",
"name": "Alpha",
"type": "desktop",
})
req := httptest.NewRequest(http.MethodGet, "/api/devices/candidates?deviceId="+device.ID, nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusUnauthorized {
t.Fatalf("expected 401 for missing device credentials, got %d", resp.Code)
}
}
func TestProtectedRoutesRejectMismatchedDeviceIdentity(t *testing.T) {
router, _ := newTestRouter()
alpha := registerDevice(t, router, map[string]any{
"device_id": "alpha",
"name": "Alpha",
"type": "desktop",
})
bravo := registerDevice(t, router, map[string]any{
"device_id": "bravo",
"name": "Bravo",
"type": "desktop",
})
req := httptest.NewRequest(http.MethodGet, "/api/devices/candidates?deviceId="+bravo.ID, nil)
req.Header.Set("X-Device-ID", alpha.ID)
req.Header.Set("X-Device-Token", alpha.AuthToken)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusForbidden {
t.Fatalf("expected 403 for mismatched device identity, got %d", resp.Code)
}
}
func TestTransferStatusUpdateRequiresParticipantOwnership(t *testing.T) {
router, _ := newTestRouter()
sender := registerDevice(t, router, map[string]any{
"device_id": "sender",
"name": "Sender",
"type": "desktop",
})
receiver := registerDevice(t, router, map[string]any{
"device_id": "receiver",
"name": "Receiver",
"type": "desktop",
})
attacker := registerDevice(t, router, map[string]any{
"device_id": "attacker",
"name": "Attacker",
"type": "desktop",
})
transfer := createTransfer(t, router, sender, map[string]any{
"kind": "text",
"name": "text-message",
"content": "hello",
"sender_device_id": sender.ID,
"receiver_device_id": receiver.ID,
})
body, err := json.Marshal(map[string]any{
"final_status": "completed",
})
if err != nil {
t.Fatalf("marshal update status body: %v", err)
}
req := httptest.NewRequest(http.MethodPatch, "/api/transfers/"+transfer.ID+"/status", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Device-ID", attacker.ID)
req.Header.Set("X-Device-Token", attacker.AuthToken)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusForbidden {
t.Fatalf("expected 403 for non-participant transfer update, got %d", resp.Code)
}
}
func newTestRouter() (http.Handler, *store.MemoryStore) {
memStore := store.NewMemoryStore(model.RuntimeConfig{})
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
deviceService := service.NewDeviceService(memStore, nil, nil)
deps := Dependencies{
Config: config.Config{},
Logger: logger,
Store: memStore,
DeviceService: deviceService,
RoomService: service.NewRoomService(memStore, 0),
TransferService: service.NewTransferService(memStore),
AdminService: service.NewAdminService(memStore, config.AdminConfig{}, nil, nil),
Hub: ws.NewHub(logger, deviceService, nil),
StorageReady: true,
RedisReady: true,
}
return NewHTTPHandler(deps).Router(), memStore
}
func registerDevice(t *testing.T, router http.Handler, payload map[string]any) registeredDevice {
t.Helper()
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal register body: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/api/devices/register", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected register 200, got %d: %s", resp.Code, resp.Body.String())
}
var payloadWrapper struct {
Data registeredDevice `json:"data"`
}
if err := json.Unmarshal(resp.Body.Bytes(), &payloadWrapper); err != nil {
t.Fatalf("decode register response: %v", err)
}
if payloadWrapper.Data.ID == "" || payloadWrapper.Data.AuthToken == "" {
t.Fatalf("expected device registration to return id and auth token, got %+v", payloadWrapper.Data)
}
return payloadWrapper.Data
}
func createTransfer(t *testing.T, router http.Handler, device registeredDevice, payload map[string]any) transferRecord {
t.Helper()
body, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal create transfer body: %v", err)
}
req := httptest.NewRequest(http.MethodPost, "/api/transfers", bytes.NewReader(body))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Device-ID", device.ID)
req.Header.Set("X-Device-Token", device.AuthToken)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
if resp.Code != http.StatusOK {
t.Fatalf("expected create transfer 200, got %d: %s", resp.Code, resp.Body.String())
}
var payloadWrapper struct {
Data transferRecord `json:"data"`
}
if err := json.Unmarshal(resp.Body.Bytes(), &payloadWrapper); err != nil {
t.Fatalf("decode transfer response: %v", err)
}
if payloadWrapper.Data.ID == "" {
t.Fatalf("expected transfer id in response, got %+v", payloadWrapper.Data)
}
return payloadWrapper.Data
}

View File

@@ -0,0 +1,125 @@
package model
import "time"
const (
RoomStatusWaiting = "waiting"
RoomStatusJoined = "joined"
RoomStatusCanceled = "canceled"
RoomStatusExpired = "expired"
ChannelP2P = "p2p"
ChannelTURN = "turn"
ChannelMinIO = "minio"
TransferPending = "pending"
TransferConnecting = "connecting"
TransferP2PTransferring = "p2p_transferring"
TransferTURNRelaying = "turn_relaying"
TransferFallbackUploading = "fallback_uploading"
TransferCompleted = "completed"
TransferFailed = "failed"
TransferCancelled = "cancelled"
)
type RuntimeConfig struct {
MaxMinIOFallbackSizeBytes int64 `json:"max_minio_fallback_size_bytes"`
MinIOCapacityBytes int64 `json:"minio_capacity_bytes"`
MinIORetentionHours int `json:"minio_retention_hours"`
MinIOUsageAlertPercent int `json:"minio_usage_alert_percent"`
P2PConnectTimeoutSec int `json:"p2p_connect_timeout_sec"`
TURNConnectTimeoutSec int `json:"turn_connect_timeout_sec"`
MinIOFallbackEnabled bool `json:"minio_fallback_enabled"`
TURNURLs []string `json:"turn_urls"`
TURNUsername string `json:"turn_username"`
TURNPassword string `json:"turn_password"`
}
type Device struct {
ID string `json:"id"`
Name string `json:"name"`
Type string `json:"type"`
UserAgent string `json:"user_agent,omitempty"`
NetworkGroupKey string `json:"network_group_key,omitempty"`
PublicIPHash string `json:"public_ip_hash,omitempty"`
IsOnline bool `json:"is_online"`
LastSeenAt time.Time `json:"last_seen_at"`
CreatedAt time.Time `json:"created_at"`
}
type Room struct {
Code string `json:"code"`
CreatorDeviceID string `json:"creator_device_id"`
JoinerDeviceID string `json:"joiner_device_id,omitempty"`
Status string `json:"status"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
}
type Transfer struct {
ID string `json:"id"`
SessionID string `json:"session_id,omitempty"`
Kind string `json:"kind"`
Name string `json:"name"`
Content string `json:"content,omitempty"`
SizeBytes int64 `json:"size_bytes"`
SenderDeviceID string `json:"sender_device_id"`
ReceiverDeviceID string `json:"receiver_device_id"`
TransferStrategy string `json:"transfer_strategy"`
CurrentChannel string `json:"current_channel"`
FallbackAllowed bool `json:"fallback_allowed"`
FinalStatus string `json:"final_status"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
FallbackReason string `json:"fallback_reason,omitempty"`
ObjectKey string `json:"object_key,omitempty"`
ExpiresAt *time.Time `json:"expires_at,omitempty"`
}
type FallbackObject struct {
TransferID string `json:"transfer_id"`
ObjectKey string `json:"object_key"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
CleanedAt *time.Time `json:"cleaned_at,omitempty"`
CleanupState string `json:"cleanup_state"`
}
type PendingFallbackDownload struct {
TransferID string `json:"transfer_id"`
Name string `json:"name"`
SizeBytes int64 `json:"size_bytes"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
DownloadPath string `json:"download_path"`
SenderDeviceID string `json:"sender_device_id"`
}
type MinIOStorageOverview struct {
Enabled bool `json:"enabled"`
UsedBytes int64 `json:"used_bytes"`
CapacityBytes int64 `json:"capacity_bytes"`
RemainingBytes int64 `json:"remaining_bytes"`
UsagePercent int `json:"usage_percent"`
ObjectCount int `json:"object_count"`
}
type AdminSession struct {
Token string `json:"token"`
CreatedAt time.Time `json:"created_at"`
}
type DeviceSession struct {
DeviceID string `json:"device_id"`
Token string `json:"token"`
CreatedAt time.Time `json:"created_at"`
ExpiresAt time.Time `json:"expires_at"`
}
type SignalEnvelope struct {
Type string `json:"type"`
DeviceID string `json:"device_id,omitempty"`
TargetDeviceID string `json:"target_device_id,omitempty"`
Payload interface{} `json:"payload,omitempty"`
}

View File

@@ -0,0 +1,59 @@
package scheduler
import (
"context"
"log/slog"
"time"
"filefast/backend/internal/storage"
"filefast/backend/internal/store"
"github.com/robfig/cron/v3"
)
type CleanupScheduler struct {
logger *slog.Logger
store *store.MemoryStore
minioClient *storage.MinIOClient
cron *cron.Cron
}
func NewCleanupScheduler(logger *slog.Logger, store *store.MemoryStore, minioClient *storage.MinIOClient) *CleanupScheduler {
return &CleanupScheduler{
logger: logger,
store: store,
minioClient: minioClient,
cron: cron.New(),
}
}
func (s *CleanupScheduler) Start() {
_, err := s.cron.AddFunc("@daily", s.cleanupExpiredFallbacks)
if err != nil {
s.logger.Error("failed to register cleanup cron", "error", err)
return
}
s.cron.Start()
}
func (s *CleanupScheduler) Stop() {
ctx := s.cron.Stop()
<-ctx.Done()
}
func (s *CleanupScheduler) cleanupExpiredFallbacks() {
now := time.Now()
for _, object := range s.store.ListExpiredFallbackObjects(now) {
if s.minioClient != nil {
if err := s.minioClient.RemoveObject(context.Background(), object.ObjectKey); err != nil {
s.logger.Warn("failed to remove expired fallback object", "transfer_id", object.TransferID, "error", err)
continue
}
}
object.CleanupState = "cleaned"
object.CleanedAt = &now
s.store.SaveFallbackObject(object)
s.logger.Info("cleaned expired fallback object", "transfer_id", object.TransferID, "object_key", object.ObjectKey)
}
}

View File

@@ -0,0 +1,95 @@
package service
import (
"context"
"errors"
"time"
"filefast/backend/internal/config"
"filefast/backend/internal/model"
"filefast/backend/internal/store"
"github.com/google/uuid"
)
type AdminService struct {
store *store.MemoryStore
config config.AdminConfig
sessionStore adminSessionStore
authStore adminCredentialStore
sessionTTL time.Duration
}
type adminSessionStore interface {
SaveAdminSession(context.Context, model.AdminSession, time.Duration) error
HasAdminSession(context.Context, string) (bool, error)
}
type adminCredentialStore interface {
ValidateAdminCredentials(context.Context, string, string) (bool, error)
}
func NewAdminService(store *store.MemoryStore, cfg config.AdminConfig, sessionStore adminSessionStore, authStore adminCredentialStore) *AdminService {
return &AdminService{
store: store,
config: cfg,
sessionStore: sessionStore,
authStore: authStore,
sessionTTL: 24 * time.Hour,
}
}
func (s *AdminService) Login(username, password string) (model.AdminSession, error) {
if s.authStore != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
ok, err := s.authStore.ValidateAdminCredentials(ctx, username, password)
if err == nil {
if !ok {
return model.AdminSession{}, errors.New("invalid admin credentials")
}
return s.issueSession()
}
}
if username != s.config.Username || password != s.config.Password {
return model.AdminSession{}, errors.New("invalid admin credentials")
}
return s.issueSession()
}
func (s *AdminService) issueSession() (model.AdminSession, error) {
session := model.AdminSession{
Token: uuid.NewString(),
CreatedAt: time.Now(),
}
session = s.store.SaveAdminSession(session)
if s.sessionStore != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := s.sessionStore.SaveAdminSession(ctx, session, s.sessionTTL); err != nil {
return model.AdminSession{}, err
}
}
return session, nil
}
func (s *AdminService) ValidateToken(token string) bool {
if s.store.HasAdminSession(token) {
return true
}
if s.sessionStore == nil {
return false
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
ok, err := s.sessionStore.HasAdminSession(ctx, token)
return err == nil && ok
}

View File

@@ -0,0 +1,210 @@
package service
import (
"context"
"sort"
"strings"
"time"
"filefast/backend/internal/model"
"filefast/backend/internal/store"
"github.com/google/uuid"
)
type DeviceService struct {
store *store.MemoryStore
presenceStore devicePresenceStore
sessionStore deviceSessionStore
presenceTTL time.Duration
sessionTTL time.Duration
}
type devicePresenceStore interface {
SetDevicePresence(context.Context, string, bool, time.Time, time.Duration) error
GetDevicePresence(context.Context, []string) (map[string]bool, error)
}
type deviceSessionStore interface {
SaveDeviceSession(context.Context, model.DeviceSession, time.Duration) error
ValidateDeviceSession(context.Context, string, string) (bool, error)
}
type RegisterDeviceInput struct {
DeviceID string `json:"device_id"`
Name string `json:"name" binding:"required"`
Type string `json:"type" binding:"required"`
NetworkGroupKey string `json:"network_group_key"`
PublicIPHash string `json:"public_ip_hash"`
}
func NewDeviceService(store *store.MemoryStore, presenceStore devicePresenceStore, sessionStore deviceSessionStore) *DeviceService {
return &DeviceService{
store: store,
presenceStore: presenceStore,
sessionStore: sessionStore,
presenceTTL: 45 * time.Second,
sessionTTL: 30 * 24 * time.Hour,
}
}
func (s *DeviceService) Register(input RegisterDeviceInput, userAgent, claimedToken string) (model.Device, model.DeviceSession) {
now := time.Now()
id := strings.TrimSpace(input.DeviceID)
if id == "" {
id = uuid.NewString()
}
device, exists := s.store.GetDevice(id)
if exists && !s.ValidateSession(id, strings.TrimSpace(claimedToken)) {
id = uuid.NewString()
exists = false
}
if !exists {
device = model.Device{
ID: id,
CreatedAt: now,
}
}
device.Name = input.Name
device.Type = input.Type
device.UserAgent = userAgent
device.NetworkGroupKey = input.NetworkGroupKey
device.PublicIPHash = input.PublicIPHash
device.LastSeenAt = now
device.IsOnline = true
device = s.store.UpsertDevice(device)
session := s.issueSession(device.ID)
s.syncPresence(device)
return device, session
}
func (s *DeviceService) Heartbeat(deviceID string) (model.Device, bool) {
device, ok := s.store.GetDevice(deviceID)
if !ok {
return model.Device{}, false
}
device.LastSeenAt = time.Now()
device.IsOnline = true
device = s.store.UpsertDevice(device)
s.syncPresence(device)
return device, true
}
func (s *DeviceService) SetOnline(deviceID string, online bool) (model.Device, bool) {
device, ok := s.store.GetDevice(deviceID)
if !ok {
return model.Device{}, false
}
device.IsOnline = online
device.LastSeenAt = time.Now()
device = s.store.UpsertDevice(device)
s.syncPresence(device)
return device, true
}
func (s *DeviceService) ListCandidates(currentDeviceID string) []model.Device {
current, _ := s.store.GetDevice(currentDeviceID)
devices := s.store.ListDevices()
s.applyPresence(devices)
candidates := make([]model.Device, 0, len(devices))
for _, device := range devices {
if device.ID == currentDeviceID || !device.IsOnline {
continue
}
candidates = append(candidates, device)
}
sort.SliceStable(candidates, func(i, j int) bool {
leftSameNetwork := current.NetworkGroupKey != "" && candidates[i].NetworkGroupKey == current.NetworkGroupKey
rightSameNetwork := current.NetworkGroupKey != "" && candidates[j].NetworkGroupKey == current.NetworkGroupKey
if leftSameNetwork != rightSameNetwork {
return leftSameNetwork
}
return candidates[i].LastSeenAt.After(candidates[j].LastSeenAt)
})
return candidates
}
func (s *DeviceService) ValidateSession(deviceID, token string) bool {
deviceID = strings.TrimSpace(deviceID)
token = strings.TrimSpace(token)
if deviceID == "" || token == "" {
return false
}
if s.store.ValidateDeviceSession(deviceID, token) {
return true
}
if s.sessionStore == nil {
return false
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
ok, err := s.sessionStore.ValidateDeviceSession(ctx, deviceID, token)
return err == nil && ok
}
func (s *DeviceService) issueSession(deviceID string) model.DeviceSession {
session := model.DeviceSession{
DeviceID: deviceID,
Token: uuid.NewString(),
CreatedAt: time.Now(),
ExpiresAt: time.Now().Add(s.sessionTTL),
}
session = s.store.SaveDeviceSession(session)
if s.sessionStore != nil {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = s.sessionStore.SaveDeviceSession(ctx, session, s.sessionTTL)
}
return session
}
func (s *DeviceService) syncPresence(device model.Device) {
if s.presenceStore == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
_ = s.presenceStore.SetDevicePresence(ctx, device.ID, device.IsOnline, device.LastSeenAt, s.presenceTTL)
}
func (s *DeviceService) applyPresence(devices []model.Device) {
if s.presenceStore == nil || len(devices) == 0 {
return
}
deviceIDs := make([]string, 0, len(devices))
for _, device := range devices {
if device.ID != "" {
deviceIDs = append(deviceIDs, device.ID)
}
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
statuses, err := s.presenceStore.GetDevicePresence(ctx, deviceIDs)
if err != nil {
return
}
for index := range devices {
if online, ok := statuses[devices[index].ID]; ok {
devices[index].IsOnline = online
}
}
}

View File

@@ -0,0 +1,58 @@
package service
import (
"testing"
"filefast/backend/internal/model"
"filefast/backend/internal/store"
)
func TestRegisterReusesKnownDeviceOnlyWithValidToken(t *testing.T) {
memStore := store.NewMemoryStore(model.RuntimeConfig{})
deviceService := NewDeviceService(memStore, nil, nil)
device, session := deviceService.Register(RegisterDeviceInput{
DeviceID: "known-device",
Name: "Alpha",
Type: "desktop",
}, "ua/1.0", "")
if device.ID != "known-device" {
t.Fatalf("expected first registration to keep requested device id, got %q", device.ID)
}
if !deviceService.ValidateSession(device.ID, session.Token) {
t.Fatal("expected issued device token to validate")
}
hijacked, hijackedSession := deviceService.Register(RegisterDeviceInput{
DeviceID: "known-device",
Name: "Mallory",
Type: "desktop",
}, "ua/1.0", "")
if hijacked.ID == device.ID {
t.Fatal("expected registration without token to receive a new device id")
}
if !deviceService.ValidateSession(hijacked.ID, hijackedSession.Token) {
t.Fatal("expected replacement device token to validate")
}
restored, rotatedSession := deviceService.Register(RegisterDeviceInput{
DeviceID: "known-device",
Name: "Alpha",
Type: "desktop",
}, "ua/1.0", session.Token)
if restored.ID != device.ID {
t.Fatalf("expected valid token to reclaim original device id, got %q", restored.ID)
}
if rotatedSession.Token == session.Token {
t.Fatal("expected registration to rotate the device token")
}
if deviceService.ValidateSession(restored.ID, session.Token) {
t.Fatal("expected rotated token to invalidate the old token")
}
if !deviceService.ValidateSession(restored.ID, rotatedSession.Token) {
t.Fatal("expected rotated device token to validate")
}
}

View File

@@ -0,0 +1,84 @@
package service
import (
"errors"
"fmt"
"math/rand/v2"
"time"
"filefast/backend/internal/model"
"filefast/backend/internal/store"
)
type RoomService struct {
store *store.MemoryStore
ttl time.Duration
}
func NewRoomService(store *store.MemoryStore, ttl time.Duration) *RoomService {
return &RoomService{
store: store,
ttl: ttl,
}
}
func (s *RoomService) CreateRoom(creatorDeviceID string) (model.Room, error) {
if creatorDeviceID == "" {
return model.Room{}, errors.New("creator_device_id is required")
}
now := time.Now()
for range 20 {
code := fmt.Sprintf("%04d", rand.IntN(10000))
if room, ok := s.store.GetRoom(code); ok && room.ExpiresAt.After(now) && room.Status != model.RoomStatusExpired {
continue
}
room := model.Room{
Code: code,
CreatorDeviceID: creatorDeviceID,
Status: model.RoomStatusWaiting,
CreatedAt: now,
ExpiresAt: now.Add(s.ttl),
}
return s.store.UpsertRoom(room), nil
}
return model.Room{}, errors.New("failed to allocate room code")
}
func (s *RoomService) JoinRoom(code, joinerDeviceID string) (model.Room, error) {
room, ok := s.store.GetRoom(code)
if !ok {
return model.Room{}, errors.New("room not found")
}
now := time.Now()
if !room.ExpiresAt.After(now) {
room.Status = model.RoomStatusExpired
s.store.UpsertRoom(room)
return model.Room{}, errors.New("room expired")
}
if room.Status != model.RoomStatusWaiting {
return model.Room{}, errors.New("room unavailable")
}
room.JoinerDeviceID = joinerDeviceID
room.Status = model.RoomStatusJoined
return s.store.UpsertRoom(room), nil
}
func (s *RoomService) CancelRoom(code, requesterID string) (model.Room, error) {
room, ok := s.store.GetRoom(code)
if !ok {
return model.Room{}, errors.New("room not found")
}
if room.CreatorDeviceID != requesterID {
return model.Room{}, errors.New("only creator can cancel room")
}
room.Status = model.RoomStatusCanceled
return s.store.UpsertRoom(room), nil
}

View File

@@ -0,0 +1,138 @@
package service
import (
"errors"
"fmt"
"time"
"filefast/backend/internal/model"
"filefast/backend/internal/store"
"github.com/google/uuid"
)
type TransferService struct {
store *store.MemoryStore
}
type CreateTransferInput struct {
SessionID string `json:"session_id"`
Kind string `json:"kind" binding:"required"`
Name string `json:"name"`
Content string `json:"content"`
SizeBytes int64 `json:"size_bytes"`
SenderDeviceID string `json:"sender_device_id" binding:"required"`
ReceiverDeviceID string `json:"receiver_device_id" binding:"required"`
}
type UpdateTransferStatusInput struct {
CurrentChannel string `json:"current_channel"`
FinalStatus string `json:"final_status"`
FallbackReason string `json:"fallback_reason"`
}
func NewTransferService(store *store.MemoryStore) *TransferService {
return &TransferService{store: store}
}
func (s *TransferService) Create(input CreateTransferInput) (model.Transfer, error) {
switch input.Kind {
case "file":
if input.Name == "" {
return model.Transfer{}, errors.New("file name is required")
}
if input.SizeBytes <= 0 {
return model.Transfer{}, errors.New("file size must be greater than zero")
}
case "text":
if input.Content == "" {
return model.Transfer{}, errors.New("text content is required")
}
if input.Name == "" {
input.Name = "text-message"
}
default:
return model.Transfer{}, errors.New("unsupported transfer kind")
}
runtime := s.store.RuntimeConfig()
now := time.Now()
fallbackAllowed := input.Kind == "file" && runtime.MinIOFallbackEnabled
strategy := "p2p_turn"
if fallbackAllowed {
strategy = "p2p_turn_minio"
}
transfer := model.Transfer{
ID: uuid.NewString(),
SessionID: input.SessionID,
Kind: input.Kind,
Name: input.Name,
Content: input.Content,
SizeBytes: input.SizeBytes,
SenderDeviceID: input.SenderDeviceID,
ReceiverDeviceID: input.ReceiverDeviceID,
TransferStrategy: strategy,
CurrentChannel: model.ChannelP2P,
FallbackAllowed: fallbackAllowed,
FinalStatus: model.TransferPending,
CreatedAt: now,
UpdatedAt: now,
}
return s.store.UpsertTransfer(transfer), nil
}
func (s *TransferService) UpdateStatus(transferID string, input UpdateTransferStatusInput) (model.Transfer, error) {
transfer, ok := s.store.GetTransfer(transferID)
if !ok {
return model.Transfer{}, errors.New("transfer not found")
}
if input.CurrentChannel != "" {
transfer.CurrentChannel = input.CurrentChannel
}
if input.FinalStatus != "" {
transfer.FinalStatus = input.FinalStatus
}
if input.FallbackReason != "" {
transfer.FallbackReason = input.FallbackReason
}
transfer.UpdatedAt = time.Now()
return s.store.UpsertTransfer(transfer), nil
}
func (s *TransferService) PrepareFallback(transferID string) (model.Transfer, model.FallbackObject, error) {
transfer, ok := s.store.GetTransfer(transferID)
if !ok {
return model.Transfer{}, model.FallbackObject{}, errors.New("transfer not found")
}
if !transfer.FallbackAllowed {
return model.Transfer{}, model.FallbackObject{}, errors.New("transfer cannot use minio fallback")
}
if object, ok := s.store.GetFallbackObject(transfer.ID); ok && object.CleanedAt == nil && object.ExpiresAt.After(time.Now()) {
return transfer, object, nil
}
runtime := s.store.RuntimeConfig()
now := time.Now()
expireAt := now.Add(time.Duration(runtime.MinIORetentionHours) * time.Hour)
objectKey := fmt.Sprintf("fallback/%s/%d-%s", now.Format("20060102"), now.Unix(), transfer.ID)
transfer.ObjectKey = objectKey
transfer.ExpiresAt = &expireAt
transfer.UpdatedAt = now
transfer = s.store.UpsertTransfer(transfer)
object := model.FallbackObject{
TransferID: transfer.ID,
ObjectKey: objectKey,
SizeBytes: transfer.SizeBytes,
CreatedAt: now,
ExpiresAt: expireAt,
CleanupState: "uploading",
}
object = s.store.SaveFallbackObject(object)
return transfer, object, nil
}

View File

@@ -0,0 +1,141 @@
package storage
import (
"context"
"errors"
"io"
"net/url"
"strings"
"time"
"filefast/backend/internal/config"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
)
type MinIOClient struct {
client *minio.Client
bucket string
presignExpiry time.Duration
enabled bool
}
func NewMinIOClient(cfg config.MinIOConfig) (*MinIOClient, error) {
if cfg.Endpoint == "" || cfg.AccessKey == "" || cfg.SecretKey == "" {
return &MinIOClient{
bucket: cfg.Bucket,
presignExpiry: cfg.PresignExpiry,
enabled: false,
}, nil
}
client, err := minio.New(cfg.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
Secure: cfg.UseSSL,
})
if err != nil {
return nil, err
}
return &MinIOClient{
client: client,
bucket: cfg.Bucket,
presignExpiry: cfg.PresignExpiry,
enabled: true,
}, nil
}
func (c *MinIOClient) Enabled() bool {
return c != nil && c.enabled
}
func (c *MinIOClient) EnsureBucket(ctx context.Context) error {
if !c.Enabled() {
return nil
}
exists, err := c.client.BucketExists(ctx, c.bucket)
if err != nil {
return err
}
if exists {
return nil
}
return c.client.MakeBucket(ctx, c.bucket, minio.MakeBucketOptions{})
}
func (c *MinIOClient) PresignUpload(ctx context.Context, objectKey string) (*url.URL, error) {
if !c.Enabled() {
return nil, errors.New("minio is disabled")
}
return c.client.PresignedPutObject(ctx, c.bucket, objectKey, c.presignExpiry)
}
func (c *MinIOClient) PresignDownload(ctx context.Context, objectKey string) (*url.URL, error) {
return c.PresignDownloadWithFilename(ctx, objectKey, "")
}
func (c *MinIOClient) PresignDownloadWithFilename(ctx context.Context, objectKey, filename string) (*url.URL, error) {
if !c.Enabled() {
return nil, errors.New("minio is disabled")
}
var reqParams url.Values
if filename = sanitizeDownloadFilename(filename); filename != "" {
reqParams = make(url.Values)
reqParams.Set("response-content-disposition", `attachment; filename="`+filename+`"`)
}
return c.client.PresignedGetObject(ctx, c.bucket, objectKey, c.presignExpiry, reqParams)
}
func (c *MinIOClient) UploadObject(ctx context.Context, objectKey string, reader io.Reader, size int64, contentType string) error {
if !c.Enabled() {
return errors.New("minio is disabled")
}
options := minio.PutObjectOptions{
ContentType: contentType,
}
_, err := c.client.PutObject(ctx, c.bucket, objectKey, reader, size, options)
return err
}
func (c *MinIOClient) OpenObject(ctx context.Context, objectKey string) (*minio.Object, minio.ObjectInfo, error) {
if !c.Enabled() {
return nil, minio.ObjectInfo{}, errors.New("minio is disabled")
}
object, err := c.client.GetObject(ctx, c.bucket, objectKey, minio.GetObjectOptions{})
if err != nil {
return nil, minio.ObjectInfo{}, err
}
info, err := object.Stat()
if err != nil {
_ = object.Close()
return nil, minio.ObjectInfo{}, err
}
return object, info, nil
}
func (c *MinIOClient) RemoveObject(ctx context.Context, objectKey string) error {
if !c.Enabled() {
return nil
}
return c.client.RemoveObject(ctx, c.bucket, objectKey, minio.RemoveObjectOptions{})
}
func sanitizeDownloadFilename(filename string) string {
filename = strings.TrimSpace(filename)
if filename == "" {
return ""
}
filename = strings.ReplaceAll(filename, `"`, "")
filename = strings.ReplaceAll(filename, "\r", "")
filename = strings.ReplaceAll(filename, "\n", "")
return filename
}

View File

@@ -0,0 +1,205 @@
package storage
import (
"context"
"encoding/json"
"fmt"
"time"
"filefast/backend/internal/config"
"filefast/backend/internal/model"
"github.com/redis/go-redis/v9"
)
const (
RedisRealtimeRelayChannel = "filefast:ws:relay"
RedisRealtimePresenceChannel = "filefast:ws:presence"
)
type RedisClient struct {
client *redis.Client
}
type RealtimeMessage struct {
Channel string
Payload []byte
}
func NewRedisClient(cfg config.RedisConfig) *RedisClient {
client := redis.NewClient(&redis.Options{
Addr: cfg.Addr,
Password: cfg.Password,
DB: cfg.DB,
MaxRetries: 1,
DialTimeout: 3 * time.Second,
ReadTimeout: 3 * time.Second,
WriteTimeout: 3 * time.Second,
})
return &RedisClient{client: client}
}
func (c *RedisClient) Ping(ctx context.Context) error {
if !c.available() {
return redis.ErrClosed
}
return c.client.Ping(ctx).Err()
}
func (c *RedisClient) SaveAdminSession(ctx context.Context, session model.AdminSession, ttl time.Duration) error {
if !c.available() {
return redis.ErrClosed
}
payload, err := json.Marshal(session)
if err != nil {
return err
}
return c.client.Set(ctx, adminSessionKey(session.Token), payload, ttl).Err()
}
func (c *RedisClient) HasAdminSession(ctx context.Context, token string) (bool, error) {
if !c.available() {
return false, redis.ErrClosed
}
result, err := c.client.Exists(ctx, adminSessionKey(token)).Result()
if err != nil {
return false, err
}
return result > 0, nil
}
func (c *RedisClient) SaveDeviceSession(ctx context.Context, session model.DeviceSession, ttl time.Duration) error {
if !c.available() {
return redis.ErrClosed
}
return c.client.Set(ctx, deviceSessionKey(session.DeviceID), session.Token, ttl).Err()
}
func (c *RedisClient) ValidateDeviceSession(ctx context.Context, deviceID, token string) (bool, error) {
if !c.available() {
return false, redis.ErrClosed
}
value, err := c.client.Get(ctx, deviceSessionKey(deviceID)).Result()
if err != nil {
if err == redis.Nil {
return false, nil
}
return false, err
}
return value == token, nil
}
func (c *RedisClient) SetDevicePresence(ctx context.Context, deviceID string, online bool, lastSeen time.Time, ttl time.Duration) error {
if !c.available() {
return redis.ErrClosed
}
key := devicePresenceKey(deviceID)
if !online {
return c.client.Del(ctx, key).Err()
}
return c.client.Set(ctx, key, lastSeen.UTC().Format(time.RFC3339Nano), ttl).Err()
}
func (c *RedisClient) GetDevicePresence(ctx context.Context, deviceIDs []string) (map[string]bool, error) {
if !c.available() {
return nil, redis.ErrClosed
}
statuses := make(map[string]bool, len(deviceIDs))
if len(deviceIDs) == 0 {
return statuses, nil
}
keys := make([]string, 0, len(deviceIDs))
for _, id := range deviceIDs {
keys = append(keys, devicePresenceKey(id))
}
values, err := c.client.MGet(ctx, keys...).Result()
if err != nil {
return nil, err
}
for index, id := range deviceIDs {
statuses[id] = values[index] != nil
}
return statuses, nil
}
func (c *RedisClient) PublishRealtime(ctx context.Context, channel string, payload []byte) error {
if !c.available() {
return redis.ErrClosed
}
return c.client.Publish(ctx, channel, payload).Err()
}
func (c *RedisClient) SubscribeRealtime(ctx context.Context, channels ...string) (<-chan RealtimeMessage, error) {
if !c.available() {
return nil, redis.ErrClosed
}
pubsub := c.client.Subscribe(ctx, channels...)
if _, err := pubsub.Receive(ctx); err != nil {
_ = pubsub.Close()
return nil, err
}
source := pubsub.Channel()
messages := make(chan RealtimeMessage, 64)
go func() {
defer close(messages)
defer pubsub.Close()
for {
select {
case <-ctx.Done():
return
case message, ok := <-source:
if !ok {
return
}
messages <- RealtimeMessage{
Channel: message.Channel,
Payload: []byte(message.Payload),
}
}
}
}()
return messages, nil
}
func (c *RedisClient) Close() error {
if !c.available() {
return nil
}
return c.client.Close()
}
func (c *RedisClient) available() bool {
return c != nil && c.client != nil
}
func adminSessionKey(token string) string {
return fmt.Sprintf("filefast:admin:session:%s", token)
}
func devicePresenceKey(deviceID string) string {
return fmt.Sprintf("filefast:device:online:%s", deviceID)
}
func deviceSessionKey(deviceID string) string {
return fmt.Sprintf("filefast:device:session:%s", deviceID)
}

View File

@@ -0,0 +1,125 @@
package storage
import (
"encoding/json"
"errors"
"os"
"path/filepath"
"filefast/backend/internal/model"
)
const (
runtimeConfigKey = "transfer_policy"
)
type runtimeConfigPayload struct {
MaxMinIOFallbackSizeBytes *int64 `json:"max_minio_fallback_size_bytes"`
MaxMinIOFallbackGB *int64 `json:"max_minio_fallback_gb"`
MinIOCapacityBytes *int64 `json:"minio_capacity_bytes"`
MinIOCapacityGB *int64 `json:"minio_capacity_gb"`
MinIORetentionHours *int `json:"minio_retention_hours"`
MinIOUsageAlertPercent *int `json:"minio_usage_alert_percent"`
P2PConnectTimeoutSec *int `json:"p2p_connect_timeout_sec"`
TURNConnectTimeoutSec *int `json:"turn_connect_timeout_sec"`
MinIOFallbackEnabled *bool `json:"minio_fallback_enabled"`
TURNURLs []string `json:"turn_urls"`
TURNUsername *string `json:"turn_username"`
TURNPassword *string `json:"turn_password"`
}
func decodeRuntimeConfig(raw []byte, fallback model.RuntimeConfig) (model.RuntimeConfig, error) {
cfg := fallback
if len(raw) == 0 {
return cfg, nil
}
var payload runtimeConfigPayload
if err := json.Unmarshal(raw, &payload); err != nil {
return model.RuntimeConfig{}, err
}
if payload.MaxMinIOFallbackSizeBytes != nil {
cfg.MaxMinIOFallbackSizeBytes = *payload.MaxMinIOFallbackSizeBytes
} else if payload.MaxMinIOFallbackGB != nil {
cfg.MaxMinIOFallbackSizeBytes = *payload.MaxMinIOFallbackGB * 1024 * 1024 * 1024
}
if payload.MinIOCapacityBytes != nil {
cfg.MinIOCapacityBytes = *payload.MinIOCapacityBytes
} else if payload.MinIOCapacityGB != nil {
cfg.MinIOCapacityBytes = *payload.MinIOCapacityGB * 1024 * 1024 * 1024
}
if payload.MinIORetentionHours != nil {
cfg.MinIORetentionHours = *payload.MinIORetentionHours
}
if payload.MinIOUsageAlertPercent != nil {
cfg.MinIOUsageAlertPercent = *payload.MinIOUsageAlertPercent
}
if payload.P2PConnectTimeoutSec != nil {
cfg.P2PConnectTimeoutSec = *payload.P2PConnectTimeoutSec
}
if payload.TURNConnectTimeoutSec != nil {
cfg.TURNConnectTimeoutSec = *payload.TURNConnectTimeoutSec
}
if payload.MinIOFallbackEnabled != nil {
cfg.MinIOFallbackEnabled = *payload.MinIOFallbackEnabled
}
if payload.TURNURLs != nil {
cfg.TURNURLs = payload.TURNURLs
}
if payload.TURNUsername != nil {
cfg.TURNUsername = *payload.TURNUsername
}
if payload.TURNPassword != nil {
cfg.TURNPassword = *payload.TURNPassword
}
return cfg, nil
}
func readFirstExistingFile(paths ...string) ([]byte, error) {
candidates := make([]string, 0, len(paths)*8)
seen := make(map[string]struct{})
cwd, err := os.Getwd()
if err == nil {
current := cwd
for {
for _, path := range paths {
candidate := filepath.Clean(filepath.Join(current, path))
if _, ok := seen[candidate]; ok {
continue
}
seen[candidate] = struct{}{}
candidates = append(candidates, candidate)
}
parent := filepath.Dir(current)
if parent == current {
break
}
current = parent
}
}
for _, path := range paths {
candidate := filepath.Clean(path)
if _, ok := seen[candidate]; ok {
continue
}
seen[candidate] = struct{}{}
candidates = append(candidates, candidate)
}
for _, candidate := range candidates {
data, err := os.ReadFile(candidate)
if err == nil {
return data, nil
}
if !errors.Is(err, os.ErrNotExist) {
return nil, err
}
}
return nil, os.ErrNotExist
}

View File

@@ -0,0 +1,617 @@
package storage
import (
"context"
"database/sql"
"encoding/json"
"errors"
"os"
"path/filepath"
"strings"
"time"
"filefast/backend/internal/config"
"filefast/backend/internal/model"
"filefast/backend/internal/store"
"golang.org/x/crypto/bcrypt"
_ "modernc.org/sqlite"
)
type SQLiteClient struct {
db *sql.DB
path string
}
func NewSQLiteClient(cfg config.SQLiteConfig) (*SQLiteClient, error) {
path := strings.TrimSpace(cfg.Path)
if path == "" {
return nil, errors.New("sqlite path is required")
}
if dir := filepath.Dir(path); dir != "" && dir != "." {
if err := os.MkdirAll(dir, 0o755); err != nil {
return nil, err
}
}
db, err := sql.Open("sqlite", path)
if err != nil {
return nil, err
}
db.SetMaxOpenConns(1)
db.SetMaxIdleConns(1)
db.SetConnMaxLifetime(0)
client := &SQLiteClient{
db: db,
path: path,
}
for _, pragma := range []string{
"PRAGMA journal_mode = WAL",
"PRAGMA foreign_keys = ON",
"PRAGMA busy_timeout = 5000",
} {
if _, err := db.Exec(pragma); err != nil {
_ = db.Close()
return nil, err
}
}
return client, nil
}
func (c *SQLiteClient) Ping(ctx context.Context) error {
return c.db.PingContext(ctx)
}
func (c *SQLiteClient) EnsureSchema(ctx context.Context) error {
sqlBytes, err := readFirstExistingFile(
filepath.Join("sql", "init_sqlite.sql"),
filepath.Join("backend", "sql", "init_sqlite.sql"),
)
if err != nil {
return err
}
_, err = c.db.ExecContext(ctx, string(sqlBytes))
return err
}
func (c *SQLiteClient) ResetOnlineDevices(ctx context.Context) error {
_, err := c.db.ExecContext(ctx, `UPDATE devices SET is_online = 0 WHERE is_online = 1`)
return err
}
func (c *SQLiteClient) ResetOperationalData(ctx context.Context) error {
tx, err := c.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
for _, statement := range []string{
`DELETE FROM fallback_objects`,
`DELETE FROM transfers`,
`DELETE FROM sessions`,
`DELETE FROM rooms`,
`DELETE FROM devices`,
} {
if _, err := tx.ExecContext(ctx, statement); err != nil {
return err
}
}
return tx.Commit()
}
func (c *SQLiteClient) LoadSnapshot(ctx context.Context, fallbackRuntime model.RuntimeConfig) (store.Snapshot, error) {
snapshot := store.Snapshot{
Runtime: &fallbackRuntime,
}
devices, err := c.loadDevices(ctx)
if err != nil {
return store.Snapshot{}, err
}
snapshot.Devices = devices
rooms, err := c.loadRooms(ctx)
if err != nil {
return store.Snapshot{}, err
}
snapshot.Rooms = rooms
transfers, err := c.loadTransfers(ctx)
if err != nil {
return store.Snapshot{}, err
}
snapshot.Transfers = transfers
fallbackObjects, err := c.loadFallbackObjects(ctx)
if err != nil {
return store.Snapshot{}, err
}
snapshot.FallbackObjects = fallbackObjects
runtimeCfg, err := c.loadRuntimeConfig(ctx, fallbackRuntime)
if err != nil {
return store.Snapshot{}, err
}
snapshot.Runtime = &runtimeCfg
return snapshot, nil
}
func (c *SQLiteClient) PersistDevice(ctx context.Context, device model.Device) error {
_, err := c.db.ExecContext(ctx, `
INSERT INTO devices (
id, device_code, name, type, user_agent, network_group_key, public_ip_hash,
is_online, last_seen_at, created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
device_code = excluded.device_code,
name = excluded.name,
type = excluded.type,
user_agent = excluded.user_agent,
network_group_key = excluded.network_group_key,
public_ip_hash = excluded.public_ip_hash,
is_online = excluded.is_online,
last_seen_at = excluded.last_seen_at,
updated_at = excluded.updated_at
`,
device.ID,
device.ID,
device.Name,
device.Type,
nullableString(device.UserAgent),
nullableString(device.NetworkGroupKey),
nullableString(device.PublicIPHash),
boolToInt(device.IsOnline),
encodeTime(device.LastSeenAt),
encodeTime(device.CreatedAt),
encodeTime(time.Now()),
)
return err
}
func (c *SQLiteClient) PersistRoom(ctx context.Context, room model.Room) error {
_, err := c.db.ExecContext(ctx, `
INSERT INTO rooms (
code, creator_device_id, joiner_device_id, status, created_at, expires_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(code) DO UPDATE SET
creator_device_id = excluded.creator_device_id,
joiner_device_id = excluded.joiner_device_id,
status = excluded.status,
expires_at = excluded.expires_at,
updated_at = excluded.updated_at
`,
room.Code,
room.CreatorDeviceID,
nullableString(room.JoinerDeviceID),
room.Status,
encodeTime(room.CreatedAt),
encodeTime(room.ExpiresAt),
encodeTime(time.Now()),
)
return err
}
func (c *SQLiteClient) PersistTransfer(ctx context.Context, transfer model.Transfer) error {
_, err := c.db.ExecContext(ctx, `
INSERT INTO transfers (
id, session_id, kind, name, content, size_bytes,
sender_device_id, receiver_device_id, transfer_strategy, current_channel,
fallback_allowed, final_status, fallback_reason, object_key, expires_at,
created_at, updated_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(id) DO UPDATE SET
session_id = excluded.session_id,
kind = excluded.kind,
name = excluded.name,
content = excluded.content,
size_bytes = excluded.size_bytes,
sender_device_id = excluded.sender_device_id,
receiver_device_id = excluded.receiver_device_id,
transfer_strategy = excluded.transfer_strategy,
current_channel = excluded.current_channel,
fallback_allowed = excluded.fallback_allowed,
final_status = excluded.final_status,
fallback_reason = excluded.fallback_reason,
object_key = excluded.object_key,
expires_at = excluded.expires_at,
updated_at = excluded.updated_at
`,
transfer.ID,
nullableString(transfer.SessionID),
transfer.Kind,
transfer.Name,
nullableString(transfer.Content),
transfer.SizeBytes,
transfer.SenderDeviceID,
transfer.ReceiverDeviceID,
transfer.TransferStrategy,
transfer.CurrentChannel,
boolToInt(transfer.FallbackAllowed),
transfer.FinalStatus,
nullableString(transfer.FallbackReason),
nullableString(transfer.ObjectKey),
nullableTime(transfer.ExpiresAt),
encodeTime(transfer.CreatedAt),
encodeTime(transfer.UpdatedAt),
)
return err
}
func (c *SQLiteClient) PersistFallbackObject(ctx context.Context, object model.FallbackObject) error {
_, err := c.db.ExecContext(ctx, `
INSERT INTO fallback_objects (
transfer_id, bucket, object_key, size_bytes, cleanup_state, created_at, expires_at, cleaned_at
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(transfer_id) DO UPDATE SET
bucket = excluded.bucket,
object_key = excluded.object_key,
size_bytes = excluded.size_bytes,
cleanup_state = excluded.cleanup_state,
expires_at = excluded.expires_at,
cleaned_at = excluded.cleaned_at
`,
object.TransferID,
"filefast-fallback",
object.ObjectKey,
object.SizeBytes,
object.CleanupState,
encodeTime(object.CreatedAt),
encodeTime(object.ExpiresAt),
nullableTime(object.CleanedAt),
)
return err
}
func (c *SQLiteClient) PersistRuntimeConfig(ctx context.Context, runtime model.RuntimeConfig) error {
payload, err := json.Marshal(runtime)
if err != nil {
return err
}
_, err = c.db.ExecContext(ctx, `
INSERT INTO system_configs (config_key, config_value, updated_at)
VALUES (?, ?, ?)
ON CONFLICT(config_key) DO UPDATE SET
config_value = excluded.config_value,
updated_at = excluded.updated_at
`, runtimeConfigKey, string(payload), encodeTime(time.Now()))
return err
}
func (c *SQLiteClient) EnsureAdminUser(ctx context.Context, username, password string) error {
username = strings.TrimSpace(username)
password = strings.TrimSpace(password)
if username == "" || password == "" {
return nil
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return err
}
_, err = c.db.ExecContext(ctx, `
INSERT INTO admin_users (username, password_hash, is_active, created_at, updated_at)
VALUES (?, ?, 1, ?, ?)
ON CONFLICT(username) DO NOTHING
`, username, string(hash), encodeTime(time.Now()), encodeTime(time.Now()))
return err
}
func (c *SQLiteClient) ValidateAdminCredentials(ctx context.Context, username, password string) (bool, error) {
var passwordHash string
var isActive int
err := c.db.QueryRowContext(ctx, `
SELECT password_hash, is_active
FROM admin_users
WHERE username = ?
`, strings.TrimSpace(username)).Scan(&passwordHash, &isActive)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return false, nil
}
return false, err
}
if isActive == 0 {
return false, nil
}
return bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(strings.TrimSpace(password))) == nil, nil
}
func (c *SQLiteClient) Close() error {
return c.db.Close()
}
func (c *SQLiteClient) loadDevices(ctx context.Context) ([]model.Device, error) {
rows, err := c.db.QueryContext(ctx, `
SELECT
id,
name,
type,
COALESCE(user_agent, ''),
COALESCE(network_group_key, ''),
COALESCE(public_ip_hash, ''),
is_online,
COALESCE(last_seen_at, created_at),
created_at
FROM devices
`)
if err != nil {
return nil, err
}
defer rows.Close()
var devices []model.Device
for rows.Next() {
var (
device model.Device
isOnline int
lastSeen string
createdAt string
)
if err := rows.Scan(
&device.ID,
&device.Name,
&device.Type,
&device.UserAgent,
&device.NetworkGroupKey,
&device.PublicIPHash,
&isOnline,
&lastSeen,
&createdAt,
); err != nil {
return nil, err
}
device.IsOnline = isOnline != 0
device.LastSeenAt, err = decodeTime(lastSeen)
if err != nil {
return nil, err
}
device.CreatedAt, err = decodeTime(createdAt)
if err != nil {
return nil, err
}
devices = append(devices, device)
}
return devices, rows.Err()
}
func (c *SQLiteClient) loadRooms(ctx context.Context) ([]model.Room, error) {
rows, err := c.db.QueryContext(ctx, `
SELECT
code,
creator_device_id,
COALESCE(joiner_device_id, ''),
status,
created_at,
expires_at
FROM rooms
`)
if err != nil {
return nil, err
}
defer rows.Close()
var rooms []model.Room
for rows.Next() {
var (
room model.Room
createdAt string
expiresAt string
)
if err := rows.Scan(
&room.Code,
&room.CreatorDeviceID,
&room.JoinerDeviceID,
&room.Status,
&createdAt,
&expiresAt,
); err != nil {
return nil, err
}
room.CreatedAt, err = decodeTime(createdAt)
if err != nil {
return nil, err
}
room.ExpiresAt, err = decodeTime(expiresAt)
if err != nil {
return nil, err
}
rooms = append(rooms, room)
}
return rooms, rows.Err()
}
func (c *SQLiteClient) loadTransfers(ctx context.Context) ([]model.Transfer, error) {
rows, err := c.db.QueryContext(ctx, `
SELECT
id,
COALESCE(session_id, ''),
kind,
name,
COALESCE(content, ''),
size_bytes,
sender_device_id,
receiver_device_id,
transfer_strategy,
current_channel,
fallback_allowed,
final_status,
created_at,
updated_at,
COALESCE(fallback_reason, ''),
COALESCE(object_key, ''),
expires_at
FROM transfers
`)
if err != nil {
return nil, err
}
defer rows.Close()
var transfers []model.Transfer
for rows.Next() {
var (
transfer model.Transfer
createdAt string
updatedAt string
expiresAt sql.NullString
canFallback int
)
if err := rows.Scan(
&transfer.ID,
&transfer.SessionID,
&transfer.Kind,
&transfer.Name,
&transfer.Content,
&transfer.SizeBytes,
&transfer.SenderDeviceID,
&transfer.ReceiverDeviceID,
&transfer.TransferStrategy,
&transfer.CurrentChannel,
&canFallback,
&transfer.FinalStatus,
&createdAt,
&updatedAt,
&transfer.FallbackReason,
&transfer.ObjectKey,
&expiresAt,
); err != nil {
return nil, err
}
transfer.FallbackAllowed = canFallback != 0
transfer.CreatedAt, err = decodeTime(createdAt)
if err != nil {
return nil, err
}
transfer.UpdatedAt, err = decodeTime(updatedAt)
if err != nil {
return nil, err
}
if expiresAt.Valid && strings.TrimSpace(expiresAt.String) != "" {
parsed, err := decodeTime(expiresAt.String)
if err != nil {
return nil, err
}
transfer.ExpiresAt = &parsed
}
transfers = append(transfers, transfer)
}
return transfers, rows.Err()
}
func (c *SQLiteClient) loadFallbackObjects(ctx context.Context) ([]model.FallbackObject, error) {
rows, err := c.db.QueryContext(ctx, `
SELECT
transfer_id,
object_key,
size_bytes,
created_at,
expires_at,
cleaned_at,
cleanup_state
FROM fallback_objects
`)
if err != nil {
return nil, err
}
defer rows.Close()
var objects []model.FallbackObject
for rows.Next() {
var (
object model.FallbackObject
createdAt string
expiresAt string
cleanedAt sql.NullString
)
if err := rows.Scan(
&object.TransferID,
&object.ObjectKey,
&object.SizeBytes,
&createdAt,
&expiresAt,
&cleanedAt,
&object.CleanupState,
); err != nil {
return nil, err
}
object.CreatedAt, err = decodeTime(createdAt)
if err != nil {
return nil, err
}
object.ExpiresAt, err = decodeTime(expiresAt)
if err != nil {
return nil, err
}
if cleanedAt.Valid && strings.TrimSpace(cleanedAt.String) != "" {
parsed, err := decodeTime(cleanedAt.String)
if err != nil {
return nil, err
}
object.CleanedAt = &parsed
}
objects = append(objects, object)
}
return objects, rows.Err()
}
func (c *SQLiteClient) loadRuntimeConfig(ctx context.Context, fallback model.RuntimeConfig) (model.RuntimeConfig, error) {
var raw string
err := c.db.QueryRowContext(ctx, `
SELECT config_value
FROM system_configs
WHERE config_key = ?
`, runtimeConfigKey).Scan(&raw)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return fallback, nil
}
return model.RuntimeConfig{}, err
}
return decodeRuntimeConfig([]byte(raw), fallback)
}
func encodeTime(value time.Time) string {
return value.UTC().Format(time.RFC3339Nano)
}
func nullableTime(value *time.Time) any {
if value == nil {
return nil
}
return encodeTime(*value)
}
func decodeTime(value string) (time.Time, error) {
return time.Parse(time.RFC3339Nano, value)
}
func nullableString(value string) any {
if strings.TrimSpace(value) == "" {
return nil
}
return value
}
func boolToInt(value bool) int {
if value {
return 1
}
return 0
}

View File

@@ -0,0 +1,478 @@
package store
import (
"context"
"sort"
"sync"
"time"
"filefast/backend/internal/model"
)
type Persistence interface {
PersistDevice(context.Context, model.Device) error
PersistRoom(context.Context, model.Room) error
PersistTransfer(context.Context, model.Transfer) error
PersistFallbackObject(context.Context, model.FallbackObject) error
PersistRuntimeConfig(context.Context, model.RuntimeConfig) error
}
type Snapshot struct {
Devices []model.Device
Rooms []model.Room
Transfers []model.Transfer
FallbackObjects []model.FallbackObject
Runtime *model.RuntimeConfig
}
type MemoryStore struct {
mu sync.RWMutex
devices map[string]model.Device
rooms map[string]model.Room
transfers map[string]model.Transfer
fallbackObjects map[string]model.FallbackObject
adminSessions map[string]model.AdminSession
deviceSessions map[string]model.DeviceSession
runtime model.RuntimeConfig
persistence Persistence
persistTimeout time.Duration
onPersistError func(kind, id string, err error)
}
const (
activeDeviceWindow = 2 * time.Minute
activeTransferWindow = 30 * time.Minute
recentTerminalTransferWind = 24 * time.Hour
)
func NewMemoryStore(runtime model.RuntimeConfig) *MemoryStore {
return &MemoryStore{
devices: make(map[string]model.Device),
rooms: make(map[string]model.Room),
transfers: make(map[string]model.Transfer),
fallbackObjects: make(map[string]model.FallbackObject),
adminSessions: make(map[string]model.AdminSession),
deviceSessions: make(map[string]model.DeviceSession),
runtime: runtime,
}
}
func (s *MemoryStore) UpsertDevice(device model.Device) model.Device {
s.mu.Lock()
s.devices[device.ID] = device
s.mu.Unlock()
s.persistDevice(device)
return device
}
func (s *MemoryStore) GetDevice(id string) (model.Device, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
device, ok := s.devices[id]
return device, ok
}
func (s *MemoryStore) ListDevices() []model.Device {
s.mu.RLock()
defer s.mu.RUnlock()
devices := make([]model.Device, 0, len(s.devices))
for _, device := range s.devices {
devices = append(devices, device)
}
return devices
}
func (s *MemoryStore) UpsertRoom(room model.Room) model.Room {
s.mu.Lock()
s.rooms[room.Code] = room
s.mu.Unlock()
s.persistRoom(room)
return room
}
func (s *MemoryStore) GetRoom(code string) (model.Room, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
room, ok := s.rooms[code]
return room, ok
}
func (s *MemoryStore) UpsertTransfer(transfer model.Transfer) model.Transfer {
s.mu.Lock()
s.transfers[transfer.ID] = transfer
s.mu.Unlock()
s.persistTransfer(transfer)
return transfer
}
func (s *MemoryStore) GetTransfer(id string) (model.Transfer, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
transfer, ok := s.transfers[id]
return transfer, ok
}
func (s *MemoryStore) ListRecentTransfers(limit int) []model.Transfer {
s.mu.RLock()
defer s.mu.RUnlock()
now := time.Now()
transfers := make([]model.Transfer, 0, len(s.transfers))
for _, transfer := range s.transfers {
if !isTransferVisible(transfer, now) {
continue
}
transfers = append(transfers, transfer)
}
sort.Slice(transfers, func(i, j int) bool {
return transfers[i].UpdatedAt.After(transfers[j].UpdatedAt)
})
if limit > 0 && len(transfers) > limit {
return transfers[:limit]
}
return transfers
}
func (s *MemoryStore) ListPendingFallbackDownloads(receiverDeviceID string, limit int) []model.PendingFallbackDownload {
s.mu.RLock()
defer s.mu.RUnlock()
if receiverDeviceID == "" {
return nil
}
now := time.Now()
downloads := make([]model.PendingFallbackDownload, 0, len(s.transfers))
for _, transfer := range s.transfers {
if transfer.ReceiverDeviceID != receiverDeviceID || transfer.ObjectKey == "" || transfer.FinalStatus != model.TransferCompleted {
continue
}
if transfer.ExpiresAt == nil || !transfer.ExpiresAt.After(now) {
continue
}
object, ok := s.fallbackObjects[transfer.ID]
if !ok || object.CleanedAt != nil || object.ExpiresAt.Before(now) || object.CleanupState != "ready" {
continue
}
downloads = append(downloads, model.PendingFallbackDownload{
TransferID: transfer.ID,
Name: transfer.Name,
SizeBytes: transfer.SizeBytes,
CreatedAt: transfer.CreatedAt,
ExpiresAt: object.ExpiresAt,
DownloadPath: "/api/transfers/" + transfer.ID + "/fallback/download",
SenderDeviceID: transfer.SenderDeviceID,
})
}
sort.Slice(downloads, func(i, j int) bool {
return downloads[i].CreatedAt.After(downloads[j].CreatedAt)
})
if limit > 0 && len(downloads) > limit {
return downloads[:limit]
}
return downloads
}
func (s *MemoryStore) SaveFallbackObject(object model.FallbackObject) model.FallbackObject {
s.mu.Lock()
s.fallbackObjects[object.TransferID] = object
s.mu.Unlock()
s.persistFallbackObject(object)
return object
}
func (s *MemoryStore) GetFallbackObject(transferID string) (model.FallbackObject, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
object, ok := s.fallbackObjects[transferID]
return object, ok
}
func (s *MemoryStore) ListExpiredFallbackObjects(now time.Time) []model.FallbackObject {
s.mu.RLock()
defer s.mu.RUnlock()
var objects []model.FallbackObject
for _, object := range s.fallbackObjects {
if object.CleanedAt == nil && !object.ExpiresAt.After(now) {
objects = append(objects, object)
}
}
return objects
}
func (s *MemoryStore) SaveAdminSession(session model.AdminSession) model.AdminSession {
s.mu.Lock()
defer s.mu.Unlock()
s.adminSessions[session.Token] = session
return session
}
func (s *MemoryStore) SaveDeviceSession(session model.DeviceSession) model.DeviceSession {
s.mu.Lock()
defer s.mu.Unlock()
s.deviceSessions[session.DeviceID] = session
return session
}
func (s *MemoryStore) HasAdminSession(token string) bool {
s.mu.RLock()
defer s.mu.RUnlock()
_, ok := s.adminSessions[token]
return ok
}
func (s *MemoryStore) ValidateDeviceSession(deviceID, token string) bool {
s.mu.RLock()
session, ok := s.deviceSessions[deviceID]
s.mu.RUnlock()
if !ok {
return false
}
if session.ExpiresAt.Before(time.Now()) {
s.mu.Lock()
delete(s.deviceSessions, deviceID)
s.mu.Unlock()
return false
}
return session.Token == token
}
func (s *MemoryStore) RuntimeConfig() model.RuntimeConfig {
s.mu.RLock()
defer s.mu.RUnlock()
return s.runtime
}
func (s *MemoryStore) UpdateRuntimeConfig(runtime model.RuntimeConfig) model.RuntimeConfig {
s.mu.Lock()
s.runtime = runtime
s.mu.Unlock()
s.persistRuntimeConfig(runtime)
return runtime
}
func (s *MemoryStore) SnapshotStats() map[string]int {
s.mu.RLock()
defer s.mu.RUnlock()
now := time.Now()
onlineDevices := 0
activeDevices := 0
for _, device := range s.devices {
if isDeviceActive(device, now) {
activeDevices++
}
if device.IsOnline && isDeviceActive(device, now) {
onlineDevices++
}
}
waitingRooms := 0
for _, room := range s.rooms {
if room.Status == model.RoomStatusWaiting && room.ExpiresAt.After(now) {
waitingRooms++
}
}
fallbackPending := 0
for _, object := range s.fallbackObjects {
if object.CleanedAt == nil && object.CleanupState == "ready" {
fallbackPending++
}
}
validTransfers := 0
for _, transfer := range s.transfers {
if isTransferVisible(transfer, now) {
validTransfers++
}
}
return map[string]int{
"devices_total": activeDevices,
"devices_online": onlineDevices,
"rooms_waiting": waitingRooms,
"transfers_total": validTransfers,
"transfers_cumulative": len(s.transfers),
"fallback_pending": fallbackPending,
"admin_sessions": len(s.adminSessions),
}
}
func (s *MemoryStore) SnapshotMinIOStorage() model.MinIOStorageOverview {
s.mu.RLock()
defer s.mu.RUnlock()
now := time.Now()
usedBytes := int64(0)
objectCount := 0
for _, object := range s.fallbackObjects {
if object.CleanedAt != nil || !object.ExpiresAt.After(now) || object.CleanupState != "ready" {
continue
}
usedBytes += object.SizeBytes
objectCount++
}
capacityBytes := s.runtime.MinIOCapacityBytes
if capacityBytes < 0 {
capacityBytes = 0
}
remainingBytes := capacityBytes - usedBytes
if remainingBytes < 0 {
remainingBytes = 0
}
usagePercent := 0
if capacityBytes > 0 {
usagePercent = int((usedBytes * 100) / capacityBytes)
if usagePercent > 100 {
usagePercent = 100
}
}
return model.MinIOStorageOverview{
Enabled: s.runtime.MinIOFallbackEnabled,
UsedBytes: usedBytes,
CapacityBytes: capacityBytes,
RemainingBytes: remainingBytes,
UsagePercent: usagePercent,
ObjectCount: objectCount,
}
}
func isDeviceActive(device model.Device, now time.Time) bool {
if device.LastSeenAt.IsZero() {
return false
}
return device.LastSeenAt.After(now.Add(-activeDeviceWindow))
}
func isTransferVisible(transfer model.Transfer, now time.Time) bool {
if transfer.UpdatedAt.IsZero() {
transfer.UpdatedAt = transfer.CreatedAt
}
if isTerminalTransferStatus(transfer.FinalStatus) {
return transfer.UpdatedAt.After(now.Add(-recentTerminalTransferWind))
}
return transfer.UpdatedAt.After(now.Add(-activeTransferWindow))
}
func isTerminalTransferStatus(status string) bool {
switch status {
case model.TransferCompleted, model.TransferFailed, model.TransferCancelled:
return true
default:
return false
}
}
func (s *MemoryStore) SetPersistence(persistence Persistence, timeout time.Duration, onError func(kind, id string, err error)) {
s.mu.Lock()
defer s.mu.Unlock()
s.persistence = persistence
s.persistTimeout = timeout
s.onPersistError = onError
}
func (s *MemoryStore) LoadSnapshot(snapshot Snapshot) {
s.mu.Lock()
defer s.mu.Unlock()
s.devices = make(map[string]model.Device, len(snapshot.Devices))
for _, device := range snapshot.Devices {
s.devices[device.ID] = device
}
s.rooms = make(map[string]model.Room, len(snapshot.Rooms))
for _, room := range snapshot.Rooms {
s.rooms[room.Code] = room
}
s.transfers = make(map[string]model.Transfer, len(snapshot.Transfers))
for _, transfer := range snapshot.Transfers {
s.transfers[transfer.ID] = transfer
}
s.fallbackObjects = make(map[string]model.FallbackObject, len(snapshot.FallbackObjects))
for _, object := range snapshot.FallbackObjects {
s.fallbackObjects[object.TransferID] = object
}
if snapshot.Runtime != nil {
s.runtime = *snapshot.Runtime
}
}
func (s *MemoryStore) persistDevice(device model.Device) {
s.persist(device.ID, "device", func(ctx context.Context, persistence Persistence) error {
return persistence.PersistDevice(ctx, device)
})
}
func (s *MemoryStore) persistRoom(room model.Room) {
s.persist(room.Code, "room", func(ctx context.Context, persistence Persistence) error {
return persistence.PersistRoom(ctx, room)
})
}
func (s *MemoryStore) persistTransfer(transfer model.Transfer) {
s.persist(transfer.ID, "transfer", func(ctx context.Context, persistence Persistence) error {
return persistence.PersistTransfer(ctx, transfer)
})
}
func (s *MemoryStore) persistFallbackObject(object model.FallbackObject) {
s.persist(object.TransferID, "fallback_object", func(ctx context.Context, persistence Persistence) error {
return persistence.PersistFallbackObject(ctx, object)
})
}
func (s *MemoryStore) persistRuntimeConfig(runtime model.RuntimeConfig) {
s.persist("transfer_policy", "runtime_config", func(ctx context.Context, persistence Persistence) error {
return persistence.PersistRuntimeConfig(ctx, runtime)
})
}
func (s *MemoryStore) persist(id, kind string, fn func(context.Context, Persistence) error) {
s.mu.RLock()
persistence := s.persistence
timeout := s.persistTimeout
onError := s.onPersistError
s.mu.RUnlock()
if persistence == nil {
return
}
if timeout <= 0 {
timeout = 5 * time.Second
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
if err := fn(ctx, persistence); err != nil && onError != nil {
onError(kind, id, err)
}
}

368
backend/internal/ws/hub.go Normal file
View File

@@ -0,0 +1,368 @@
package ws
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"net"
"net/http"
"net/url"
"strings"
"time"
"filefast/backend/internal/model"
"filefast/backend/internal/service"
"filefast/backend/internal/storage"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
const (
writeWait = 10 * time.Second
pongWait = 60 * time.Second
pingPeriod = (pongWait * 9) / 10
maxMessageSize = 8 * 1024 * 1024
)
type Hub struct {
logger *slog.Logger
deviceService *service.DeviceService
backplane realtimeBackplane
instanceID string
clients map[string]*Client
register chan *Client
unregister chan *Client
relay chan relayMessage
}
type Client struct {
hub *Hub
deviceID string
conn *websocket.Conn
send chan []byte
}
type relayMessage struct {
targetDeviceID string
payload []byte
}
type inboundMessage struct {
Type string `json:"type"`
TargetDeviceID string `json:"target_device_id"`
Payload json.RawMessage `json:"payload"`
}
type backplaneEnvelope struct {
Source string `json:"source"`
TargetDeviceID string `json:"target_device_id,omitempty"`
Payload json.RawMessage `json:"payload"`
}
type realtimeBackplane interface {
PublishRealtime(context.Context, string, []byte) error
SubscribeRealtime(context.Context, ...string) (<-chan storage.RealtimeMessage, error)
}
var upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
origin := strings.TrimSpace(r.Header.Get("Origin"))
if origin == "" {
return true
}
parsed, err := url.Parse(origin)
if err != nil {
return false
}
return originHostMatchesRequest(parsed.Host, r)
},
}
func NewHub(logger *slog.Logger, deviceService *service.DeviceService, backplane realtimeBackplane) *Hub {
return &Hub{
logger: logger,
deviceService: deviceService,
backplane: backplane,
instanceID: fmt.Sprintf("%d", time.Now().UnixNano()),
clients: make(map[string]*Client),
register: make(chan *Client),
unregister: make(chan *Client),
relay: make(chan relayMessage),
}
}
func (h *Hub) Run() {
var backplaneMessages <-chan storage.RealtimeMessage
if h.backplane != nil {
messages, err := h.backplane.SubscribeRealtime(context.Background(), storage.RedisRealtimeRelayChannel, storage.RedisRealtimePresenceChannel)
if err != nil {
h.logger.Warn("failed to subscribe redis realtime backplane", "error", err)
} else {
backplaneMessages = messages
}
}
for {
select {
case client := <-h.register:
h.clients[client.deviceID] = client
h.deviceService.SetOnline(client.deviceID, true)
h.broadcastPresence(client.deviceID, true)
case client := <-h.unregister:
if existing, ok := h.clients[client.deviceID]; ok && existing == client {
delete(h.clients, client.deviceID)
close(client.send)
h.deviceService.SetOnline(client.deviceID, false)
h.broadcastPresence(client.deviceID, false)
}
case message := <-h.relay:
h.dispatchRelay(message.targetDeviceID, message.payload)
case message, ok := <-backplaneMessages:
if !ok {
backplaneMessages = nil
continue
}
h.handleBackplaneMessage(message)
}
}
}
func (h *Hub) Handle(c *gin.Context) {
deviceID := c.Query("deviceId")
token := strings.TrimSpace(c.Query("deviceToken"))
if deviceID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "deviceId is required"})
return
}
if token == "" {
c.JSON(http.StatusUnauthorized, gin.H{"error": "deviceToken is required"})
return
}
if !h.deviceService.ValidateSession(deviceID, token) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid device credentials"})
return
}
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to upgrade websocket"})
return
}
client := &Client{
hub: h,
deviceID: deviceID,
conn: conn,
send: make(chan []byte, 32),
}
h.register <- client
go client.writePump()
client.readPump()
}
func (h *Hub) broadcastPresence(deviceID string, online bool) {
envelope := model.SignalEnvelope{
Type: "presence.update",
DeviceID: deviceID,
Payload: map[string]any{
"online": online,
},
}
data, err := json.Marshal(envelope)
if err != nil {
h.logger.Warn("failed to marshal presence update", "error", err)
return
}
h.broadcastPresencePayload(data)
h.publishBackplane(storage.RedisRealtimePresenceChannel, "", data)
}
func (h *Hub) broadcastPresencePayload(data []byte) {
for _, client := range h.clients {
client.send <- data
}
}
func (h *Hub) dispatchRelay(targetDeviceID string, payload []byte) {
if client, ok := h.clients[targetDeviceID]; ok {
client.send <- payload
}
h.publishBackplane(storage.RedisRealtimeRelayChannel, targetDeviceID, payload)
}
func (h *Hub) publishBackplane(channel, targetDeviceID string, payload []byte) {
if h.backplane == nil {
return
}
envelope, err := json.Marshal(backplaneEnvelope{
Source: h.instanceID,
TargetDeviceID: targetDeviceID,
Payload: payload,
})
if err != nil {
h.logger.Warn("failed to marshal backplane envelope", "channel", channel, "error", err)
return
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := h.backplane.PublishRealtime(ctx, channel, envelope); err != nil {
h.logger.Warn("failed to publish realtime backplane message", "channel", channel, "error", err)
}
}
func (h *Hub) handleBackplaneMessage(message storage.RealtimeMessage) {
var envelope backplaneEnvelope
if err := json.Unmarshal(message.Payload, &envelope); err != nil {
h.logger.Warn("failed to decode realtime backplane message", "channel", message.Channel, "error", err)
return
}
if envelope.Source == h.instanceID {
return
}
switch message.Channel {
case storage.RedisRealtimeRelayChannel:
if client, ok := h.clients[envelope.TargetDeviceID]; ok {
client.send <- envelope.Payload
}
case storage.RedisRealtimePresenceChannel:
h.broadcastPresencePayload(envelope.Payload)
}
}
func (c *Client) readPump() {
defer func() {
c.hub.unregister <- c
c.conn.Close()
}()
c.conn.SetReadLimit(maxMessageSize)
c.conn.SetReadDeadline(time.Now().Add(pongWait))
c.conn.SetPongHandler(func(string) error {
c.conn.SetReadDeadline(time.Now().Add(pongWait))
return nil
})
for {
var inbound inboundMessage
if err := c.conn.ReadJSON(&inbound); err != nil {
return
}
envelope := model.SignalEnvelope{
Type: inbound.Type,
DeviceID: c.deviceID,
TargetDeviceID: inbound.TargetDeviceID,
Payload: inbound.Payload,
}
data, err := json.Marshal(envelope)
if err != nil {
continue
}
c.hub.relay <- relayMessage{
targetDeviceID: inbound.TargetDeviceID,
payload: data,
}
}
}
func (c *Client) writePump() {
ticker := time.NewTicker(pingPeriod)
defer func() {
ticker.Stop()
c.conn.Close()
}()
for {
select {
case message, ok := <-c.send:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if !ok {
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
return
}
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
return
}
case <-ticker.C:
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
return
}
}
}
}
func originHostMatchesRequest(originHost string, r *http.Request) bool {
requestHosts := []string{
strings.TrimSpace(r.Host),
strings.TrimSpace(r.Header.Get("X-Forwarded-Host")),
}
originName, err := normalizeHost(originHost)
if err != nil {
return false
}
for _, host := range requestHosts {
if host == "" {
continue
}
requestName, err := normalizeHost(host)
if err != nil {
continue
}
if requestName == originName {
return true
}
}
return false
}
func normalizeHost(host string) (string, error) {
host = strings.TrimSpace(host)
if host == "" {
return "", fmt.Errorf("empty host")
}
if strings.Contains(host, "://") {
parsed, err := url.Parse(host)
if err != nil {
return "", err
}
host = parsed.Host
}
name, _, err := net.SplitHostPort(host)
if err == nil {
host = name
} else if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
host = strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
}
host = strings.ToLower(strings.TrimSpace(host))
switch host {
case "127.0.0.1", "::1":
return "localhost", nil
default:
return host, nil
}
}