first commit
This commit is contained in:
252
backend/internal/config/config.go
Normal file
252
backend/internal/config/config.go
Normal file
@@ -0,0 +1,252 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/model"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
HTTPAddress string
|
||||
LogLevel slog.Level
|
||||
RoomTTL time.Duration
|
||||
Admin AdminConfig
|
||||
SQLite SQLiteConfig
|
||||
Redis RedisConfig
|
||||
MinIO MinIOConfig
|
||||
Runtime model.RuntimeConfig
|
||||
}
|
||||
|
||||
type AdminConfig struct {
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type MinIOConfig struct {
|
||||
Endpoint string
|
||||
AccessKey string
|
||||
SecretKey string
|
||||
UseSSL bool
|
||||
Bucket string
|
||||
PresignExpiry time.Duration
|
||||
Retention time.Duration
|
||||
UsageAlertLevel int
|
||||
}
|
||||
|
||||
type RedisConfig struct {
|
||||
Addr string
|
||||
Password string
|
||||
DB int
|
||||
}
|
||||
|
||||
type SQLiteConfig struct {
|
||||
Path string
|
||||
}
|
||||
|
||||
func Load() Config {
|
||||
loadDotEnv()
|
||||
retentionHours := envInt("MINIO_RETENTION_HOURS", 2)
|
||||
capacityGB := envInt("MINIO_CAPACITY_GB", 120)
|
||||
return Config{
|
||||
HTTPAddress: envString("HTTP_ADDR", ":8080"),
|
||||
LogLevel: parseLogLevel(envString("LOG_LEVEL", "info")),
|
||||
RoomTTL: time.Duration(envInt("ROOM_TTL_SECONDS", 300)) * time.Second,
|
||||
Admin: AdminConfig{
|
||||
Username: envString("ADMIN_USERNAME", ""),
|
||||
Password: envString("ADMIN_PASSWORD", ""),
|
||||
},
|
||||
SQLite: SQLiteConfig{
|
||||
Path: envString("SQLITE_PATH", filepath.Join("backend", "data", "filefast.db")),
|
||||
},
|
||||
Redis: RedisConfig{
|
||||
Addr: envString("REDIS_ADDR", "127.0.0.1:6379"),
|
||||
Password: envString("REDIS_PASSWORD", ""),
|
||||
DB: envInt("REDIS_DB", 0),
|
||||
},
|
||||
MinIO: MinIOConfig{
|
||||
Endpoint: envString("MINIO_ENDPOINT", ""),
|
||||
AccessKey: envString("MINIO_ACCESS_KEY", ""),
|
||||
SecretKey: envString("MINIO_SECRET_KEY", ""),
|
||||
UseSSL: envBool("MINIO_USE_SSL", false),
|
||||
Bucket: envString("MINIO_BUCKET", "filefast-fallback"),
|
||||
PresignExpiry: time.Duration(envInt("MINIO_PRESIGN_MINUTES", 30)) * time.Minute,
|
||||
Retention: time.Duration(retentionHours) * time.Hour,
|
||||
UsageAlertLevel: envInt("MINIO_USAGE_ALERT_PERCENT", 85),
|
||||
},
|
||||
Runtime: model.RuntimeConfig{
|
||||
MaxMinIOFallbackSizeBytes: int64(envInt("MAX_MINIO_FALLBACK_GB", 10)) * 1024 * 1024 * 1024,
|
||||
MinIOCapacityBytes: int64(capacityGB) * 1024 * 1024 * 1024,
|
||||
MinIORetentionHours: retentionHours,
|
||||
MinIOUsageAlertPercent: envInt("MINIO_USAGE_ALERT_PERCENT", 85),
|
||||
P2PConnectTimeoutSec: envInt("P2P_CONNECT_TIMEOUT_SEC", 15),
|
||||
TURNConnectTimeoutSec: envInt("TURN_CONNECT_TIMEOUT_SEC", 20),
|
||||
MinIOFallbackEnabled: true,
|
||||
TURNURLs: envCSV("TURN_URLS"),
|
||||
TURNUsername: envString("TURN_USERNAME", ""),
|
||||
TURNPassword: envString("TURN_PASSWORD", ""),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func loadDotEnv() {
|
||||
for _, candidate := range dotEnvCandidates() {
|
||||
if loadDotEnvFile(candidate) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dotEnvCandidates() []string {
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return []string{".env", filepath.Join("backend", ".env")}
|
||||
}
|
||||
|
||||
candidates := make([]string, 0, 16)
|
||||
seen := make(map[string]struct{})
|
||||
current := cwd
|
||||
|
||||
for {
|
||||
for _, name := range []string{".env", filepath.Join("backend", ".env")} {
|
||||
candidate := filepath.Clean(filepath.Join(current, name))
|
||||
if _, ok := seen[candidate]; ok {
|
||||
continue
|
||||
}
|
||||
seen[candidate] = struct{}{}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
|
||||
parent := filepath.Dir(current)
|
||||
if parent == current {
|
||||
break
|
||||
}
|
||||
current = parent
|
||||
}
|
||||
|
||||
// Preserve the most local candidates first, but keep the historical relative
|
||||
// fallbacks at the end for compatibility.
|
||||
for _, legacy := range []string{".env", filepath.Join("backend", ".env")} {
|
||||
if !slices.Contains(candidates, legacy) {
|
||||
candidates = append(candidates, legacy)
|
||||
}
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
func loadDotEnvFile(path string) bool {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(strings.TrimPrefix(scanner.Text(), "\uFEFF"))
|
||||
if line == "" || strings.HasPrefix(line, "#") {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(line, "export ") {
|
||||
line = strings.TrimSpace(strings.TrimPrefix(line, "export "))
|
||||
}
|
||||
|
||||
key, value, ok := strings.Cut(line, "=")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
key = strings.TrimSpace(key)
|
||||
value = strings.TrimSpace(value)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := os.LookupEnv(key); exists {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(value) >= 2 {
|
||||
if (strings.HasPrefix(value, "\"") && strings.HasSuffix(value, "\"")) ||
|
||||
(strings.HasPrefix(value, "'") && strings.HasSuffix(value, "'")) {
|
||||
value = value[1 : len(value)-1]
|
||||
}
|
||||
}
|
||||
|
||||
_ = os.Setenv(key, value)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func envString(key, fallback string) string {
|
||||
value := strings.TrimSpace(os.Getenv(key))
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func envInt(key string, fallback int) int {
|
||||
value := strings.TrimSpace(os.Getenv(key))
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
|
||||
parsed, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return fallback
|
||||
}
|
||||
|
||||
return parsed
|
||||
}
|
||||
|
||||
func envBool(key string, fallback bool) bool {
|
||||
value := strings.TrimSpace(strings.ToLower(os.Getenv(key)))
|
||||
if value == "" {
|
||||
return fallback
|
||||
}
|
||||
|
||||
return value == "1" || value == "true" || value == "yes"
|
||||
}
|
||||
|
||||
func envCSV(key string) []string {
|
||||
raw := strings.TrimSpace(os.Getenv(key))
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(raw, ",")
|
||||
values := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
value := strings.TrimSpace(part)
|
||||
if value != "" {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func parseLogLevel(level string) slog.Level {
|
||||
switch strings.ToLower(strings.TrimSpace(level)) {
|
||||
case "debug":
|
||||
return slog.LevelDebug
|
||||
case "warn":
|
||||
return slog.LevelWarn
|
||||
case "error":
|
||||
return slog.LevelError
|
||||
default:
|
||||
return slog.LevelInfo
|
||||
}
|
||||
}
|
||||
593
backend/internal/handler/http.go
Normal file
593
backend/internal/handler/http.go
Normal file
@@ -0,0 +1,593 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/config"
|
||||
"filefast/backend/internal/model"
|
||||
"filefast/backend/internal/service"
|
||||
"filefast/backend/internal/storage"
|
||||
"filefast/backend/internal/store"
|
||||
"filefast/backend/internal/ws"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
type Dependencies struct {
|
||||
Config config.Config
|
||||
Logger *slog.Logger
|
||||
Store *store.MemoryStore
|
||||
DeviceService *service.DeviceService
|
||||
RoomService *service.RoomService
|
||||
TransferService *service.TransferService
|
||||
AdminService *service.AdminService
|
||||
MinIOClient *storage.MinIOClient
|
||||
Hub *ws.Hub
|
||||
StorageReady bool
|
||||
RedisReady bool
|
||||
}
|
||||
|
||||
type HTTPHandler struct {
|
||||
deps Dependencies
|
||||
}
|
||||
|
||||
func NewHTTPHandler(deps Dependencies) *HTTPHandler {
|
||||
return &HTTPHandler{deps: deps}
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) Router() *gin.Engine {
|
||||
router := gin.Default()
|
||||
|
||||
router.GET("/healthz", h.handleHealth)
|
||||
router.GET("/ws", h.deps.Hub.Handle)
|
||||
|
||||
api := router.Group("/api")
|
||||
{
|
||||
api.GET("/runtime/config", h.runtimeConfig)
|
||||
api.POST("/devices/register", h.registerDevice)
|
||||
api.POST("/admin/login", h.adminLogin)
|
||||
}
|
||||
|
||||
device := api.Group("/")
|
||||
device.Use(h.requireDevice())
|
||||
{
|
||||
device.POST("/devices/heartbeat", h.deviceHeartbeat)
|
||||
device.GET("/devices/candidates", h.listCandidates)
|
||||
device.GET("/devices/:id/pending-downloads", h.pendingFallbackDownloads)
|
||||
|
||||
device.POST("/rooms", h.createRoom)
|
||||
device.GET("/rooms/:code", h.getRoom)
|
||||
device.POST("/rooms/join", h.joinRoom)
|
||||
device.POST("/rooms/:code/cancel", h.cancelRoom)
|
||||
|
||||
device.POST("/transfers", h.createTransfer)
|
||||
device.PATCH("/transfers/:id/status", h.updateTransferStatus)
|
||||
device.POST("/transfers/:id/fallback/presign", h.presignFallback)
|
||||
device.PUT("/transfers/:id/fallback/upload", h.uploadFallback)
|
||||
device.GET("/transfers/:id/fallback/download", h.downloadFallback)
|
||||
}
|
||||
|
||||
admin := api.Group("/admin")
|
||||
admin.Use(h.requireAdmin())
|
||||
{
|
||||
admin.GET("/stats", h.adminStats)
|
||||
admin.GET("/config", h.adminConfig)
|
||||
admin.PUT("/config", h.updateAdminConfig)
|
||||
admin.GET("/transfers/recent", h.recentTransfers)
|
||||
}
|
||||
|
||||
return router
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) handleHealth(c *gin.Context) {
|
||||
status := "ok"
|
||||
if !h.deps.StorageReady || !h.deps.RedisReady {
|
||||
status = "degraded"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": status,
|
||||
"minio_enabled": h.deps.MinIOClient != nil && h.deps.MinIOClient.Enabled(),
|
||||
"storage_ready": h.deps.StorageReady,
|
||||
"redis_ready": h.deps.RedisReady,
|
||||
"turn_enabled": len(h.deps.Store.RuntimeConfig().TURNURLs) > 0,
|
||||
"room_ttl_sec": int(h.deps.Config.RoomTTL.Seconds()),
|
||||
"server_time_unix": time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) runtimeConfig(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"data": h.deps.Store.RuntimeConfig()})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) registerDevice(c *gin.Context) {
|
||||
var input service.RegisterDeviceInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
device, session := h.deps.DeviceService.Register(input, c.Request.UserAgent(), c.GetHeader("X-Device-Token"))
|
||||
c.JSON(http.StatusOK, gin.H{"data": gin.H{
|
||||
"id": device.ID,
|
||||
"name": device.Name,
|
||||
"type": device.Type,
|
||||
"user_agent": device.UserAgent,
|
||||
"network_group_key": device.NetworkGroupKey,
|
||||
"public_ip_hash": device.PublicIPHash,
|
||||
"is_online": device.IsOnline,
|
||||
"last_seen_at": device.LastSeenAt,
|
||||
"created_at": device.CreatedAt,
|
||||
"auth_token": session.Token,
|
||||
"auth_expires_at": session.ExpiresAt,
|
||||
}})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) deviceHeartbeat(c *gin.Context) {
|
||||
var input struct {
|
||||
DeviceID string `json:"device_id" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if !h.ensureAuthenticatedDevice(c, input.DeviceID) {
|
||||
return
|
||||
}
|
||||
|
||||
device, ok := h.deps.DeviceService.Heartbeat(input.DeviceID)
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "device not found"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": device})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) listCandidates(c *gin.Context) {
|
||||
deviceID := c.Query("deviceId")
|
||||
if deviceID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "deviceId is required"})
|
||||
return
|
||||
}
|
||||
if !h.ensureAuthenticatedDevice(c, deviceID) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": h.deps.DeviceService.ListCandidates(deviceID)})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) createRoom(c *gin.Context) {
|
||||
var input struct {
|
||||
CreatorDeviceID string `json:"creator_device_id" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !h.ensureAuthenticatedDevice(c, input.CreatorDeviceID) {
|
||||
return
|
||||
}
|
||||
|
||||
room, err := h.deps.RoomService.CreateRoom(input.CreatorDeviceID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": room})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) getRoom(c *gin.Context) {
|
||||
room, ok := h.deps.Store.GetRoom(c.Param("code"))
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "room not found"})
|
||||
return
|
||||
}
|
||||
deviceID := h.authenticatedDeviceID(c)
|
||||
if room.CreatorDeviceID != deviceID && room.JoinerDeviceID != deviceID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "room access denied"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": room})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) joinRoom(c *gin.Context) {
|
||||
var input struct {
|
||||
Code string `json:"code" binding:"required"`
|
||||
JoinerDeviceID string `json:"joiner_device_id" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !h.ensureAuthenticatedDevice(c, input.JoinerDeviceID) {
|
||||
return
|
||||
}
|
||||
|
||||
room, err := h.deps.RoomService.JoinRoom(input.Code, input.JoinerDeviceID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": room})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) cancelRoom(c *gin.Context) {
|
||||
var input struct {
|
||||
RequesterID string `json:"requester_id" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !h.ensureAuthenticatedDevice(c, input.RequesterID) {
|
||||
return
|
||||
}
|
||||
|
||||
room, err := h.deps.RoomService.CancelRoom(c.Param("code"), input.RequesterID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": room})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) createTransfer(c *gin.Context) {
|
||||
var input service.CreateTransferInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if !h.ensureAuthenticatedDevice(c, input.SenderDeviceID) {
|
||||
return
|
||||
}
|
||||
|
||||
transfer, err := h.deps.TransferService.Create(input)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": transfer})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) updateTransferStatus(c *gin.Context) {
|
||||
transfer, ok := h.deps.Store.GetTransfer(c.Param("id"))
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "transfer not found"})
|
||||
return
|
||||
}
|
||||
deviceID := h.authenticatedDeviceID(c)
|
||||
if transfer.SenderDeviceID != deviceID && transfer.ReceiverDeviceID != deviceID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "transfer access denied"})
|
||||
return
|
||||
}
|
||||
|
||||
var input service.UpdateTransferStatusInput
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
transfer, err := h.deps.TransferService.UpdateStatus(c.Param("id"), input)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": transfer})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) presignFallback(c *gin.Context) {
|
||||
if h.deps.MinIOClient == nil || !h.deps.MinIOClient.Enabled() {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "minio fallback is disabled"})
|
||||
return
|
||||
}
|
||||
|
||||
transfer, ok := h.deps.Store.GetTransfer(c.Param("id"))
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "transfer not found"})
|
||||
return
|
||||
}
|
||||
if transfer.SenderDeviceID != h.authenticatedDeviceID(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "transfer access denied"})
|
||||
return
|
||||
}
|
||||
|
||||
transfer, object, err := h.deps.TransferService.PrepareFallback(c.Param("id"))
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := h.ensureFallbackBucket(ctx, transfer.ID); err != nil {
|
||||
h.deps.Logger.Warn("minio ensure bucket failed", "transfer_id", transfer.ID, "error", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
uploadURL, err := h.deps.MinIOClient.PresignUpload(ctx, object.ObjectKey)
|
||||
if err != nil {
|
||||
h.deps.Logger.Warn("minio presign upload failed", "transfer_id", transfer.ID, "object_key", object.ObjectKey, "error", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
downloadURL, err := h.deps.MinIOClient.PresignDownload(ctx, object.ObjectKey)
|
||||
if err != nil {
|
||||
h.deps.Logger.Warn("minio presign download failed", "transfer_id", transfer.ID, "object_key", object.ObjectKey, "error", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": gin.H{
|
||||
"transfer": transfer,
|
||||
"upload_url": uploadURL.String(),
|
||||
"download_url": downloadURL.String(),
|
||||
"download_path": fallbackDownloadPath(transfer.ID),
|
||||
"expires_at": object.ExpiresAt,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) uploadFallback(c *gin.Context) {
|
||||
if h.deps.MinIOClient == nil || !h.deps.MinIOClient.Enabled() {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "minio fallback is disabled"})
|
||||
return
|
||||
}
|
||||
|
||||
transfer, ok := h.deps.Store.GetTransfer(c.Param("id"))
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "transfer not found"})
|
||||
return
|
||||
}
|
||||
if transfer.ObjectKey == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "fallback object is not prepared"})
|
||||
return
|
||||
}
|
||||
if transfer.SenderDeviceID != h.authenticatedDeviceID(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "transfer access denied"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
if err := h.ensureFallbackBucket(ctx, transfer.ID); err != nil {
|
||||
h.deps.Logger.Warn("minio ensure bucket failed", "transfer_id", transfer.ID, "error", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
contentLength := c.Request.ContentLength
|
||||
if contentLength <= 0 {
|
||||
contentLength = transfer.SizeBytes
|
||||
}
|
||||
|
||||
contentType := strings.TrimSpace(c.GetHeader("Content-Type"))
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
|
||||
if err := h.deps.MinIOClient.UploadObject(ctx, transfer.ObjectKey, c.Request.Body, contentLength, contentType); err != nil {
|
||||
h.deps.Logger.Warn("minio upload object failed", "transfer_id", transfer.ID, "object_key", transfer.ObjectKey, "error", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if object, ok := h.deps.Store.GetFallbackObject(transfer.ID); ok {
|
||||
object.SizeBytes = contentLength
|
||||
object.CleanupState = "ready"
|
||||
h.deps.Store.SaveFallbackObject(object)
|
||||
}
|
||||
|
||||
downloadURL, err := h.deps.MinIOClient.PresignDownload(ctx, transfer.ObjectKey)
|
||||
if err != nil {
|
||||
h.deps.Logger.Warn("minio presign download failed after upload", "transfer_id", transfer.ID, "object_key", transfer.ObjectKey, "error", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": gin.H{
|
||||
"download_url": downloadURL.String(),
|
||||
"download_path": fallbackDownloadPath(transfer.ID),
|
||||
"object_key": transfer.ObjectKey,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) pendingFallbackDownloads(c *gin.Context) {
|
||||
deviceID := strings.TrimSpace(c.Param("id"))
|
||||
if deviceID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "device id is required"})
|
||||
return
|
||||
}
|
||||
if !h.ensureAuthenticatedDevice(c, deviceID) {
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": h.deps.Store.ListPendingFallbackDownloads(deviceID, 20),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) downloadFallback(c *gin.Context) {
|
||||
if h.deps.MinIOClient == nil || !h.deps.MinIOClient.Enabled() {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "minio fallback is disabled"})
|
||||
return
|
||||
}
|
||||
|
||||
transfer, ok := h.deps.Store.GetTransfer(c.Param("id"))
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "transfer not found"})
|
||||
return
|
||||
}
|
||||
if transfer.ObjectKey == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "fallback object is not prepared"})
|
||||
return
|
||||
}
|
||||
if transfer.ReceiverDeviceID != h.authenticatedDeviceID(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "transfer access denied"})
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(c.Request.Context(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
filename := filepath.Base(transfer.Name)
|
||||
if filename == "." || filename == "" {
|
||||
filename = "download.bin"
|
||||
}
|
||||
|
||||
downloadURL, err := h.deps.MinIOClient.PresignDownloadWithFilename(ctx, transfer.ObjectKey, filename)
|
||||
if err != nil {
|
||||
h.deps.Logger.Warn("minio presign download failed", "transfer_id", transfer.ID, "object_key", transfer.ObjectKey, "error", err)
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, downloadURL.String())
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) adminLogin(c *gin.Context) {
|
||||
var input struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
session, err := h.deps.AdminService.Login(input.Username, input.Password)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": session})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) adminStats(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": gin.H{
|
||||
"stats": h.deps.Store.SnapshotStats(),
|
||||
"minio": h.deps.Store.SnapshotMinIOStorage(),
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) adminConfig(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"data": h.deps.Store.RuntimeConfig()})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) updateAdminConfig(c *gin.Context) {
|
||||
var input model.RuntimeConfig
|
||||
if err := c.ShouldBindJSON(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"data": h.deps.Store.UpdateRuntimeConfig(input)})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) recentTransfers(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"data": h.deps.Store.ListRecentTransfers(20)})
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) requireAdmin() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
header := c.GetHeader("Authorization")
|
||||
if header == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing authorization header"})
|
||||
return
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(strings.TrimPrefix(header, "Bearer"))
|
||||
if token == "" || !h.deps.AdminService.ValidateToken(token) {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid admin token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) requireDevice() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
deviceID := strings.TrimSpace(c.GetHeader("X-Device-ID"))
|
||||
token := strings.TrimSpace(c.GetHeader("X-Device-Token"))
|
||||
if deviceID == "" || token == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing device credentials"})
|
||||
return
|
||||
}
|
||||
if !h.deps.DeviceService.ValidateSession(deviceID, token) {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid device credentials"})
|
||||
return
|
||||
}
|
||||
c.Set("device_id", deviceID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) authenticatedDeviceID(c *gin.Context) string {
|
||||
value, ok := c.Get("device_id")
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
deviceID, _ := value.(string)
|
||||
return deviceID
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) ensureAuthenticatedDevice(c *gin.Context, expected string) bool {
|
||||
if h.authenticatedDeviceID(c) != strings.TrimSpace(expected) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"error": "device access denied"})
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func fallbackDownloadPath(transferID string) string {
|
||||
return "/api/transfers/" + url.PathEscape(transferID) + "/fallback/download"
|
||||
}
|
||||
|
||||
func contentDisposition(filename string) string {
|
||||
escaped := strings.ReplaceAll(filename, `"`, "")
|
||||
return `attachment; filename="` + escaped + `"`
|
||||
}
|
||||
|
||||
func (h *HTTPHandler) ensureFallbackBucket(ctx context.Context, transferID string) error {
|
||||
if h.deps.MinIOClient == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := h.deps.MinIOClient.EnsureBucket(ctx)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Some MinIO users can read/write objects in an existing bucket but cannot run
|
||||
// bucket-existence or bucket-creation checks. In that case we keep going.
|
||||
if strings.Contains(strings.ToLower(err.Error()), "access denied") {
|
||||
h.deps.Logger.Info("minio bucket ensure skipped due to limited permissions", "transfer_id", transferID)
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
195
backend/internal/handler/http_test.go
Normal file
195
backend/internal/handler/http_test.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"filefast/backend/internal/config"
|
||||
"filefast/backend/internal/model"
|
||||
"filefast/backend/internal/service"
|
||||
"filefast/backend/internal/store"
|
||||
"filefast/backend/internal/ws"
|
||||
)
|
||||
|
||||
type registeredDevice struct {
|
||||
ID string `json:"id"`
|
||||
AuthToken string `json:"auth_token"`
|
||||
}
|
||||
|
||||
type transferRecord struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
func TestProtectedRoutesRequireDeviceCredentials(t *testing.T) {
|
||||
router, _ := newTestRouter()
|
||||
|
||||
device := registerDevice(t, router, map[string]any{
|
||||
"device_id": "alpha",
|
||||
"name": "Alpha",
|
||||
"type": "desktop",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/devices/candidates?deviceId="+device.ID, nil)
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected 401 for missing device credentials, got %d", resp.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProtectedRoutesRejectMismatchedDeviceIdentity(t *testing.T) {
|
||||
router, _ := newTestRouter()
|
||||
|
||||
alpha := registerDevice(t, router, map[string]any{
|
||||
"device_id": "alpha",
|
||||
"name": "Alpha",
|
||||
"type": "desktop",
|
||||
})
|
||||
bravo := registerDevice(t, router, map[string]any{
|
||||
"device_id": "bravo",
|
||||
"name": "Bravo",
|
||||
"type": "desktop",
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/devices/candidates?deviceId="+bravo.ID, nil)
|
||||
req.Header.Set("X-Device-ID", alpha.ID)
|
||||
req.Header.Set("X-Device-Token", alpha.AuthToken)
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403 for mismatched device identity, got %d", resp.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransferStatusUpdateRequiresParticipantOwnership(t *testing.T) {
|
||||
router, _ := newTestRouter()
|
||||
|
||||
sender := registerDevice(t, router, map[string]any{
|
||||
"device_id": "sender",
|
||||
"name": "Sender",
|
||||
"type": "desktop",
|
||||
})
|
||||
receiver := registerDevice(t, router, map[string]any{
|
||||
"device_id": "receiver",
|
||||
"name": "Receiver",
|
||||
"type": "desktop",
|
||||
})
|
||||
attacker := registerDevice(t, router, map[string]any{
|
||||
"device_id": "attacker",
|
||||
"name": "Attacker",
|
||||
"type": "desktop",
|
||||
})
|
||||
|
||||
transfer := createTransfer(t, router, sender, map[string]any{
|
||||
"kind": "text",
|
||||
"name": "text-message",
|
||||
"content": "hello",
|
||||
"sender_device_id": sender.ID,
|
||||
"receiver_device_id": receiver.ID,
|
||||
})
|
||||
|
||||
body, err := json.Marshal(map[string]any{
|
||||
"final_status": "completed",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("marshal update status body: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPatch, "/api/transfers/"+transfer.ID+"/status", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Device-ID", attacker.ID)
|
||||
req.Header.Set("X-Device-Token", attacker.AuthToken)
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403 for non-participant transfer update, got %d", resp.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func newTestRouter() (http.Handler, *store.MemoryStore) {
|
||||
memStore := store.NewMemoryStore(model.RuntimeConfig{})
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
deviceService := service.NewDeviceService(memStore, nil, nil)
|
||||
deps := Dependencies{
|
||||
Config: config.Config{},
|
||||
Logger: logger,
|
||||
Store: memStore,
|
||||
DeviceService: deviceService,
|
||||
RoomService: service.NewRoomService(memStore, 0),
|
||||
TransferService: service.NewTransferService(memStore),
|
||||
AdminService: service.NewAdminService(memStore, config.AdminConfig{}, nil, nil),
|
||||
Hub: ws.NewHub(logger, deviceService, nil),
|
||||
StorageReady: true,
|
||||
RedisReady: true,
|
||||
}
|
||||
|
||||
return NewHTTPHandler(deps).Router(), memStore
|
||||
}
|
||||
|
||||
func registerDevice(t *testing.T, router http.Handler, payload map[string]any) registeredDevice {
|
||||
t.Helper()
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal register body: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/devices/register", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected register 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
var payloadWrapper struct {
|
||||
Data registeredDevice `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &payloadWrapper); err != nil {
|
||||
t.Fatalf("decode register response: %v", err)
|
||||
}
|
||||
if payloadWrapper.Data.ID == "" || payloadWrapper.Data.AuthToken == "" {
|
||||
t.Fatalf("expected device registration to return id and auth token, got %+v", payloadWrapper.Data)
|
||||
}
|
||||
|
||||
return payloadWrapper.Data
|
||||
}
|
||||
|
||||
func createTransfer(t *testing.T, router http.Handler, device registeredDevice, payload map[string]any) transferRecord {
|
||||
t.Helper()
|
||||
body, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal create transfer body: %v", err)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/transfers", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Device-ID", device.ID)
|
||||
req.Header.Set("X-Device-Token", device.AuthToken)
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected create transfer 200, got %d: %s", resp.Code, resp.Body.String())
|
||||
}
|
||||
|
||||
var payloadWrapper struct {
|
||||
Data transferRecord `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &payloadWrapper); err != nil {
|
||||
t.Fatalf("decode transfer response: %v", err)
|
||||
}
|
||||
if payloadWrapper.Data.ID == "" {
|
||||
t.Fatalf("expected transfer id in response, got %+v", payloadWrapper.Data)
|
||||
}
|
||||
|
||||
return payloadWrapper.Data
|
||||
}
|
||||
125
backend/internal/model/types.go
Normal file
125
backend/internal/model/types.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package model
|
||||
|
||||
import "time"
|
||||
|
||||
const (
|
||||
RoomStatusWaiting = "waiting"
|
||||
RoomStatusJoined = "joined"
|
||||
RoomStatusCanceled = "canceled"
|
||||
RoomStatusExpired = "expired"
|
||||
|
||||
ChannelP2P = "p2p"
|
||||
ChannelTURN = "turn"
|
||||
ChannelMinIO = "minio"
|
||||
|
||||
TransferPending = "pending"
|
||||
TransferConnecting = "connecting"
|
||||
TransferP2PTransferring = "p2p_transferring"
|
||||
TransferTURNRelaying = "turn_relaying"
|
||||
TransferFallbackUploading = "fallback_uploading"
|
||||
TransferCompleted = "completed"
|
||||
TransferFailed = "failed"
|
||||
TransferCancelled = "cancelled"
|
||||
)
|
||||
|
||||
type RuntimeConfig struct {
|
||||
MaxMinIOFallbackSizeBytes int64 `json:"max_minio_fallback_size_bytes"`
|
||||
MinIOCapacityBytes int64 `json:"minio_capacity_bytes"`
|
||||
MinIORetentionHours int `json:"minio_retention_hours"`
|
||||
MinIOUsageAlertPercent int `json:"minio_usage_alert_percent"`
|
||||
P2PConnectTimeoutSec int `json:"p2p_connect_timeout_sec"`
|
||||
TURNConnectTimeoutSec int `json:"turn_connect_timeout_sec"`
|
||||
MinIOFallbackEnabled bool `json:"minio_fallback_enabled"`
|
||||
TURNURLs []string `json:"turn_urls"`
|
||||
TURNUsername string `json:"turn_username"`
|
||||
TURNPassword string `json:"turn_password"`
|
||||
}
|
||||
|
||||
type Device struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
UserAgent string `json:"user_agent,omitempty"`
|
||||
NetworkGroupKey string `json:"network_group_key,omitempty"`
|
||||
PublicIPHash string `json:"public_ip_hash,omitempty"`
|
||||
IsOnline bool `json:"is_online"`
|
||||
LastSeenAt time.Time `json:"last_seen_at"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type Room struct {
|
||||
Code string `json:"code"`
|
||||
CreatorDeviceID string `json:"creator_device_id"`
|
||||
JoinerDeviceID string `json:"joiner_device_id,omitempty"`
|
||||
Status string `json:"status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type Transfer struct {
|
||||
ID string `json:"id"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
Kind string `json:"kind"`
|
||||
Name string `json:"name"`
|
||||
Content string `json:"content,omitempty"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
SenderDeviceID string `json:"sender_device_id"`
|
||||
ReceiverDeviceID string `json:"receiver_device_id"`
|
||||
TransferStrategy string `json:"transfer_strategy"`
|
||||
CurrentChannel string `json:"current_channel"`
|
||||
FallbackAllowed bool `json:"fallback_allowed"`
|
||||
FinalStatus string `json:"final_status"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
FallbackReason string `json:"fallback_reason,omitempty"`
|
||||
ObjectKey string `json:"object_key,omitempty"`
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"`
|
||||
}
|
||||
|
||||
type FallbackObject struct {
|
||||
TransferID string `json:"transfer_id"`
|
||||
ObjectKey string `json:"object_key"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
CleanedAt *time.Time `json:"cleaned_at,omitempty"`
|
||||
CleanupState string `json:"cleanup_state"`
|
||||
}
|
||||
|
||||
type PendingFallbackDownload struct {
|
||||
TransferID string `json:"transfer_id"`
|
||||
Name string `json:"name"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
DownloadPath string `json:"download_path"`
|
||||
SenderDeviceID string `json:"sender_device_id"`
|
||||
}
|
||||
|
||||
type MinIOStorageOverview struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
UsedBytes int64 `json:"used_bytes"`
|
||||
CapacityBytes int64 `json:"capacity_bytes"`
|
||||
RemainingBytes int64 `json:"remaining_bytes"`
|
||||
UsagePercent int `json:"usage_percent"`
|
||||
ObjectCount int `json:"object_count"`
|
||||
}
|
||||
|
||||
type AdminSession struct {
|
||||
Token string `json:"token"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
}
|
||||
|
||||
type DeviceSession struct {
|
||||
DeviceID string `json:"device_id"`
|
||||
Token string `json:"token"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
}
|
||||
|
||||
type SignalEnvelope struct {
|
||||
Type string `json:"type"`
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
TargetDeviceID string `json:"target_device_id,omitempty"`
|
||||
Payload interface{} `json:"payload,omitempty"`
|
||||
}
|
||||
59
backend/internal/scheduler/cleanup.go
Normal file
59
backend/internal/scheduler/cleanup.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package scheduler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/storage"
|
||||
"filefast/backend/internal/store"
|
||||
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
type CleanupScheduler struct {
|
||||
logger *slog.Logger
|
||||
store *store.MemoryStore
|
||||
minioClient *storage.MinIOClient
|
||||
cron *cron.Cron
|
||||
}
|
||||
|
||||
func NewCleanupScheduler(logger *slog.Logger, store *store.MemoryStore, minioClient *storage.MinIOClient) *CleanupScheduler {
|
||||
return &CleanupScheduler{
|
||||
logger: logger,
|
||||
store: store,
|
||||
minioClient: minioClient,
|
||||
cron: cron.New(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CleanupScheduler) Start() {
|
||||
_, err := s.cron.AddFunc("@daily", s.cleanupExpiredFallbacks)
|
||||
if err != nil {
|
||||
s.logger.Error("failed to register cleanup cron", "error", err)
|
||||
return
|
||||
}
|
||||
s.cron.Start()
|
||||
}
|
||||
|
||||
func (s *CleanupScheduler) Stop() {
|
||||
ctx := s.cron.Stop()
|
||||
<-ctx.Done()
|
||||
}
|
||||
|
||||
func (s *CleanupScheduler) cleanupExpiredFallbacks() {
|
||||
now := time.Now()
|
||||
for _, object := range s.store.ListExpiredFallbackObjects(now) {
|
||||
if s.minioClient != nil {
|
||||
if err := s.minioClient.RemoveObject(context.Background(), object.ObjectKey); err != nil {
|
||||
s.logger.Warn("failed to remove expired fallback object", "transfer_id", object.TransferID, "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
object.CleanupState = "cleaned"
|
||||
object.CleanedAt = &now
|
||||
s.store.SaveFallbackObject(object)
|
||||
s.logger.Info("cleaned expired fallback object", "transfer_id", object.TransferID, "object_key", object.ObjectKey)
|
||||
}
|
||||
}
|
||||
95
backend/internal/service/admin_service.go
Normal file
95
backend/internal/service/admin_service.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/config"
|
||||
"filefast/backend/internal/model"
|
||||
"filefast/backend/internal/store"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type AdminService struct {
|
||||
store *store.MemoryStore
|
||||
config config.AdminConfig
|
||||
sessionStore adminSessionStore
|
||||
authStore adminCredentialStore
|
||||
sessionTTL time.Duration
|
||||
}
|
||||
|
||||
type adminSessionStore interface {
|
||||
SaveAdminSession(context.Context, model.AdminSession, time.Duration) error
|
||||
HasAdminSession(context.Context, string) (bool, error)
|
||||
}
|
||||
|
||||
type adminCredentialStore interface {
|
||||
ValidateAdminCredentials(context.Context, string, string) (bool, error)
|
||||
}
|
||||
|
||||
func NewAdminService(store *store.MemoryStore, cfg config.AdminConfig, sessionStore adminSessionStore, authStore adminCredentialStore) *AdminService {
|
||||
return &AdminService{
|
||||
store: store,
|
||||
config: cfg,
|
||||
sessionStore: sessionStore,
|
||||
authStore: authStore,
|
||||
sessionTTL: 24 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AdminService) Login(username, password string) (model.AdminSession, error) {
|
||||
if s.authStore != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ok, err := s.authStore.ValidateAdminCredentials(ctx, username, password)
|
||||
if err == nil {
|
||||
if !ok {
|
||||
return model.AdminSession{}, errors.New("invalid admin credentials")
|
||||
}
|
||||
return s.issueSession()
|
||||
}
|
||||
}
|
||||
|
||||
if username != s.config.Username || password != s.config.Password {
|
||||
return model.AdminSession{}, errors.New("invalid admin credentials")
|
||||
}
|
||||
|
||||
return s.issueSession()
|
||||
}
|
||||
|
||||
func (s *AdminService) issueSession() (model.AdminSession, error) {
|
||||
session := model.AdminSession{
|
||||
Token: uuid.NewString(),
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
session = s.store.SaveAdminSession(session)
|
||||
|
||||
if s.sessionStore != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
if err := s.sessionStore.SaveAdminSession(ctx, session, s.sessionTTL); err != nil {
|
||||
return model.AdminSession{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (s *AdminService) ValidateToken(token string) bool {
|
||||
if s.store.HasAdminSession(token) {
|
||||
return true
|
||||
}
|
||||
|
||||
if s.sessionStore == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ok, err := s.sessionStore.HasAdminSession(ctx, token)
|
||||
return err == nil && ok
|
||||
}
|
||||
210
backend/internal/service/device_service.go
Normal file
210
backend/internal/service/device_service.go
Normal file
@@ -0,0 +1,210 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/model"
|
||||
"filefast/backend/internal/store"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type DeviceService struct {
|
||||
store *store.MemoryStore
|
||||
presenceStore devicePresenceStore
|
||||
sessionStore deviceSessionStore
|
||||
presenceTTL time.Duration
|
||||
sessionTTL time.Duration
|
||||
}
|
||||
|
||||
type devicePresenceStore interface {
|
||||
SetDevicePresence(context.Context, string, bool, time.Time, time.Duration) error
|
||||
GetDevicePresence(context.Context, []string) (map[string]bool, error)
|
||||
}
|
||||
|
||||
type deviceSessionStore interface {
|
||||
SaveDeviceSession(context.Context, model.DeviceSession, time.Duration) error
|
||||
ValidateDeviceSession(context.Context, string, string) (bool, error)
|
||||
}
|
||||
|
||||
type RegisterDeviceInput struct {
|
||||
DeviceID string `json:"device_id"`
|
||||
Name string `json:"name" binding:"required"`
|
||||
Type string `json:"type" binding:"required"`
|
||||
NetworkGroupKey string `json:"network_group_key"`
|
||||
PublicIPHash string `json:"public_ip_hash"`
|
||||
}
|
||||
|
||||
func NewDeviceService(store *store.MemoryStore, presenceStore devicePresenceStore, sessionStore deviceSessionStore) *DeviceService {
|
||||
return &DeviceService{
|
||||
store: store,
|
||||
presenceStore: presenceStore,
|
||||
sessionStore: sessionStore,
|
||||
presenceTTL: 45 * time.Second,
|
||||
sessionTTL: 30 * 24 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DeviceService) Register(input RegisterDeviceInput, userAgent, claimedToken string) (model.Device, model.DeviceSession) {
|
||||
now := time.Now()
|
||||
id := strings.TrimSpace(input.DeviceID)
|
||||
if id == "" {
|
||||
id = uuid.NewString()
|
||||
}
|
||||
|
||||
device, exists := s.store.GetDevice(id)
|
||||
if exists && !s.ValidateSession(id, strings.TrimSpace(claimedToken)) {
|
||||
id = uuid.NewString()
|
||||
exists = false
|
||||
}
|
||||
if !exists {
|
||||
device = model.Device{
|
||||
ID: id,
|
||||
CreatedAt: now,
|
||||
}
|
||||
}
|
||||
|
||||
device.Name = input.Name
|
||||
device.Type = input.Type
|
||||
device.UserAgent = userAgent
|
||||
device.NetworkGroupKey = input.NetworkGroupKey
|
||||
device.PublicIPHash = input.PublicIPHash
|
||||
device.LastSeenAt = now
|
||||
device.IsOnline = true
|
||||
|
||||
device = s.store.UpsertDevice(device)
|
||||
session := s.issueSession(device.ID)
|
||||
s.syncPresence(device)
|
||||
return device, session
|
||||
}
|
||||
|
||||
func (s *DeviceService) Heartbeat(deviceID string) (model.Device, bool) {
|
||||
device, ok := s.store.GetDevice(deviceID)
|
||||
if !ok {
|
||||
return model.Device{}, false
|
||||
}
|
||||
|
||||
device.LastSeenAt = time.Now()
|
||||
device.IsOnline = true
|
||||
device = s.store.UpsertDevice(device)
|
||||
s.syncPresence(device)
|
||||
return device, true
|
||||
}
|
||||
|
||||
func (s *DeviceService) SetOnline(deviceID string, online bool) (model.Device, bool) {
|
||||
device, ok := s.store.GetDevice(deviceID)
|
||||
if !ok {
|
||||
return model.Device{}, false
|
||||
}
|
||||
|
||||
device.IsOnline = online
|
||||
device.LastSeenAt = time.Now()
|
||||
device = s.store.UpsertDevice(device)
|
||||
s.syncPresence(device)
|
||||
return device, true
|
||||
}
|
||||
|
||||
func (s *DeviceService) ListCandidates(currentDeviceID string) []model.Device {
|
||||
current, _ := s.store.GetDevice(currentDeviceID)
|
||||
devices := s.store.ListDevices()
|
||||
s.applyPresence(devices)
|
||||
candidates := make([]model.Device, 0, len(devices))
|
||||
|
||||
for _, device := range devices {
|
||||
if device.ID == currentDeviceID || !device.IsOnline {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, device)
|
||||
}
|
||||
|
||||
sort.SliceStable(candidates, func(i, j int) bool {
|
||||
leftSameNetwork := current.NetworkGroupKey != "" && candidates[i].NetworkGroupKey == current.NetworkGroupKey
|
||||
rightSameNetwork := current.NetworkGroupKey != "" && candidates[j].NetworkGroupKey == current.NetworkGroupKey
|
||||
if leftSameNetwork != rightSameNetwork {
|
||||
return leftSameNetwork
|
||||
}
|
||||
return candidates[i].LastSeenAt.After(candidates[j].LastSeenAt)
|
||||
})
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
func (s *DeviceService) ValidateSession(deviceID, token string) bool {
|
||||
deviceID = strings.TrimSpace(deviceID)
|
||||
token = strings.TrimSpace(token)
|
||||
if deviceID == "" || token == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.store.ValidateDeviceSession(deviceID, token) {
|
||||
return true
|
||||
}
|
||||
|
||||
if s.sessionStore == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
ok, err := s.sessionStore.ValidateDeviceSession(ctx, deviceID, token)
|
||||
return err == nil && ok
|
||||
}
|
||||
|
||||
func (s *DeviceService) issueSession(deviceID string) model.DeviceSession {
|
||||
session := model.DeviceSession{
|
||||
DeviceID: deviceID,
|
||||
Token: uuid.NewString(),
|
||||
CreatedAt: time.Now(),
|
||||
ExpiresAt: time.Now().Add(s.sessionTTL),
|
||||
}
|
||||
session = s.store.SaveDeviceSession(session)
|
||||
|
||||
if s.sessionStore != nil {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.sessionStore.SaveDeviceSession(ctx, session, s.sessionTTL)
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *DeviceService) syncPresence(device model.Device) {
|
||||
if s.presenceStore == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
_ = s.presenceStore.SetDevicePresence(ctx, device.ID, device.IsOnline, device.LastSeenAt, s.presenceTTL)
|
||||
}
|
||||
|
||||
func (s *DeviceService) applyPresence(devices []model.Device) {
|
||||
if s.presenceStore == nil || len(devices) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
deviceIDs := make([]string, 0, len(devices))
|
||||
for _, device := range devices {
|
||||
if device.ID != "" {
|
||||
deviceIDs = append(deviceIDs, device.ID)
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
statuses, err := s.presenceStore.GetDevicePresence(ctx, deviceIDs)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for index := range devices {
|
||||
if online, ok := statuses[devices[index].ID]; ok {
|
||||
devices[index].IsOnline = online
|
||||
}
|
||||
}
|
||||
}
|
||||
58
backend/internal/service/device_service_test.go
Normal file
58
backend/internal/service/device_service_test.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"filefast/backend/internal/model"
|
||||
"filefast/backend/internal/store"
|
||||
)
|
||||
|
||||
func TestRegisterReusesKnownDeviceOnlyWithValidToken(t *testing.T) {
|
||||
memStore := store.NewMemoryStore(model.RuntimeConfig{})
|
||||
deviceService := NewDeviceService(memStore, nil, nil)
|
||||
|
||||
device, session := deviceService.Register(RegisterDeviceInput{
|
||||
DeviceID: "known-device",
|
||||
Name: "Alpha",
|
||||
Type: "desktop",
|
||||
}, "ua/1.0", "")
|
||||
|
||||
if device.ID != "known-device" {
|
||||
t.Fatalf("expected first registration to keep requested device id, got %q", device.ID)
|
||||
}
|
||||
if !deviceService.ValidateSession(device.ID, session.Token) {
|
||||
t.Fatal("expected issued device token to validate")
|
||||
}
|
||||
|
||||
hijacked, hijackedSession := deviceService.Register(RegisterDeviceInput{
|
||||
DeviceID: "known-device",
|
||||
Name: "Mallory",
|
||||
Type: "desktop",
|
||||
}, "ua/1.0", "")
|
||||
|
||||
if hijacked.ID == device.ID {
|
||||
t.Fatal("expected registration without token to receive a new device id")
|
||||
}
|
||||
if !deviceService.ValidateSession(hijacked.ID, hijackedSession.Token) {
|
||||
t.Fatal("expected replacement device token to validate")
|
||||
}
|
||||
|
||||
restored, rotatedSession := deviceService.Register(RegisterDeviceInput{
|
||||
DeviceID: "known-device",
|
||||
Name: "Alpha",
|
||||
Type: "desktop",
|
||||
}, "ua/1.0", session.Token)
|
||||
|
||||
if restored.ID != device.ID {
|
||||
t.Fatalf("expected valid token to reclaim original device id, got %q", restored.ID)
|
||||
}
|
||||
if rotatedSession.Token == session.Token {
|
||||
t.Fatal("expected registration to rotate the device token")
|
||||
}
|
||||
if deviceService.ValidateSession(restored.ID, session.Token) {
|
||||
t.Fatal("expected rotated token to invalidate the old token")
|
||||
}
|
||||
if !deviceService.ValidateSession(restored.ID, rotatedSession.Token) {
|
||||
t.Fatal("expected rotated device token to validate")
|
||||
}
|
||||
}
|
||||
84
backend/internal/service/room_service.go
Normal file
84
backend/internal/service/room_service.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/model"
|
||||
"filefast/backend/internal/store"
|
||||
)
|
||||
|
||||
type RoomService struct {
|
||||
store *store.MemoryStore
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func NewRoomService(store *store.MemoryStore, ttl time.Duration) *RoomService {
|
||||
return &RoomService{
|
||||
store: store,
|
||||
ttl: ttl,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *RoomService) CreateRoom(creatorDeviceID string) (model.Room, error) {
|
||||
if creatorDeviceID == "" {
|
||||
return model.Room{}, errors.New("creator_device_id is required")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
for range 20 {
|
||||
code := fmt.Sprintf("%04d", rand.IntN(10000))
|
||||
if room, ok := s.store.GetRoom(code); ok && room.ExpiresAt.After(now) && room.Status != model.RoomStatusExpired {
|
||||
continue
|
||||
}
|
||||
|
||||
room := model.Room{
|
||||
Code: code,
|
||||
CreatorDeviceID: creatorDeviceID,
|
||||
Status: model.RoomStatusWaiting,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
return s.store.UpsertRoom(room), nil
|
||||
}
|
||||
|
||||
return model.Room{}, errors.New("failed to allocate room code")
|
||||
}
|
||||
|
||||
func (s *RoomService) JoinRoom(code, joinerDeviceID string) (model.Room, error) {
|
||||
room, ok := s.store.GetRoom(code)
|
||||
if !ok {
|
||||
return model.Room{}, errors.New("room not found")
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if !room.ExpiresAt.After(now) {
|
||||
room.Status = model.RoomStatusExpired
|
||||
s.store.UpsertRoom(room)
|
||||
return model.Room{}, errors.New("room expired")
|
||||
}
|
||||
|
||||
if room.Status != model.RoomStatusWaiting {
|
||||
return model.Room{}, errors.New("room unavailable")
|
||||
}
|
||||
|
||||
room.JoinerDeviceID = joinerDeviceID
|
||||
room.Status = model.RoomStatusJoined
|
||||
return s.store.UpsertRoom(room), nil
|
||||
}
|
||||
|
||||
func (s *RoomService) CancelRoom(code, requesterID string) (model.Room, error) {
|
||||
room, ok := s.store.GetRoom(code)
|
||||
if !ok {
|
||||
return model.Room{}, errors.New("room not found")
|
||||
}
|
||||
|
||||
if room.CreatorDeviceID != requesterID {
|
||||
return model.Room{}, errors.New("only creator can cancel room")
|
||||
}
|
||||
|
||||
room.Status = model.RoomStatusCanceled
|
||||
return s.store.UpsertRoom(room), nil
|
||||
}
|
||||
138
backend/internal/service/transfer_service.go
Normal file
138
backend/internal/service/transfer_service.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/model"
|
||||
"filefast/backend/internal/store"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type TransferService struct {
|
||||
store *store.MemoryStore
|
||||
}
|
||||
|
||||
type CreateTransferInput struct {
|
||||
SessionID string `json:"session_id"`
|
||||
Kind string `json:"kind" binding:"required"`
|
||||
Name string `json:"name"`
|
||||
Content string `json:"content"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
SenderDeviceID string `json:"sender_device_id" binding:"required"`
|
||||
ReceiverDeviceID string `json:"receiver_device_id" binding:"required"`
|
||||
}
|
||||
|
||||
type UpdateTransferStatusInput struct {
|
||||
CurrentChannel string `json:"current_channel"`
|
||||
FinalStatus string `json:"final_status"`
|
||||
FallbackReason string `json:"fallback_reason"`
|
||||
}
|
||||
|
||||
func NewTransferService(store *store.MemoryStore) *TransferService {
|
||||
return &TransferService{store: store}
|
||||
}
|
||||
|
||||
func (s *TransferService) Create(input CreateTransferInput) (model.Transfer, error) {
|
||||
switch input.Kind {
|
||||
case "file":
|
||||
if input.Name == "" {
|
||||
return model.Transfer{}, errors.New("file name is required")
|
||||
}
|
||||
if input.SizeBytes <= 0 {
|
||||
return model.Transfer{}, errors.New("file size must be greater than zero")
|
||||
}
|
||||
case "text":
|
||||
if input.Content == "" {
|
||||
return model.Transfer{}, errors.New("text content is required")
|
||||
}
|
||||
if input.Name == "" {
|
||||
input.Name = "text-message"
|
||||
}
|
||||
default:
|
||||
return model.Transfer{}, errors.New("unsupported transfer kind")
|
||||
}
|
||||
|
||||
runtime := s.store.RuntimeConfig()
|
||||
now := time.Now()
|
||||
fallbackAllowed := input.Kind == "file" && runtime.MinIOFallbackEnabled
|
||||
strategy := "p2p_turn"
|
||||
if fallbackAllowed {
|
||||
strategy = "p2p_turn_minio"
|
||||
}
|
||||
|
||||
transfer := model.Transfer{
|
||||
ID: uuid.NewString(),
|
||||
SessionID: input.SessionID,
|
||||
Kind: input.Kind,
|
||||
Name: input.Name,
|
||||
Content: input.Content,
|
||||
SizeBytes: input.SizeBytes,
|
||||
SenderDeviceID: input.SenderDeviceID,
|
||||
ReceiverDeviceID: input.ReceiverDeviceID,
|
||||
TransferStrategy: strategy,
|
||||
CurrentChannel: model.ChannelP2P,
|
||||
FallbackAllowed: fallbackAllowed,
|
||||
FinalStatus: model.TransferPending,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
return s.store.UpsertTransfer(transfer), nil
|
||||
}
|
||||
|
||||
func (s *TransferService) UpdateStatus(transferID string, input UpdateTransferStatusInput) (model.Transfer, error) {
|
||||
transfer, ok := s.store.GetTransfer(transferID)
|
||||
if !ok {
|
||||
return model.Transfer{}, errors.New("transfer not found")
|
||||
}
|
||||
|
||||
if input.CurrentChannel != "" {
|
||||
transfer.CurrentChannel = input.CurrentChannel
|
||||
}
|
||||
if input.FinalStatus != "" {
|
||||
transfer.FinalStatus = input.FinalStatus
|
||||
}
|
||||
if input.FallbackReason != "" {
|
||||
transfer.FallbackReason = input.FallbackReason
|
||||
}
|
||||
|
||||
transfer.UpdatedAt = time.Now()
|
||||
return s.store.UpsertTransfer(transfer), nil
|
||||
}
|
||||
|
||||
func (s *TransferService) PrepareFallback(transferID string) (model.Transfer, model.FallbackObject, error) {
|
||||
transfer, ok := s.store.GetTransfer(transferID)
|
||||
if !ok {
|
||||
return model.Transfer{}, model.FallbackObject{}, errors.New("transfer not found")
|
||||
}
|
||||
if !transfer.FallbackAllowed {
|
||||
return model.Transfer{}, model.FallbackObject{}, errors.New("transfer cannot use minio fallback")
|
||||
}
|
||||
if object, ok := s.store.GetFallbackObject(transfer.ID); ok && object.CleanedAt == nil && object.ExpiresAt.After(time.Now()) {
|
||||
return transfer, object, nil
|
||||
}
|
||||
|
||||
runtime := s.store.RuntimeConfig()
|
||||
now := time.Now()
|
||||
expireAt := now.Add(time.Duration(runtime.MinIORetentionHours) * time.Hour)
|
||||
objectKey := fmt.Sprintf("fallback/%s/%d-%s", now.Format("20060102"), now.Unix(), transfer.ID)
|
||||
|
||||
transfer.ObjectKey = objectKey
|
||||
transfer.ExpiresAt = &expireAt
|
||||
transfer.UpdatedAt = now
|
||||
transfer = s.store.UpsertTransfer(transfer)
|
||||
|
||||
object := model.FallbackObject{
|
||||
TransferID: transfer.ID,
|
||||
ObjectKey: objectKey,
|
||||
SizeBytes: transfer.SizeBytes,
|
||||
CreatedAt: now,
|
||||
ExpiresAt: expireAt,
|
||||
CleanupState: "uploading",
|
||||
}
|
||||
object = s.store.SaveFallbackObject(object)
|
||||
return transfer, object, nil
|
||||
}
|
||||
141
backend/internal/storage/minio.go
Normal file
141
backend/internal/storage/minio.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/config"
|
||||
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
)
|
||||
|
||||
type MinIOClient struct {
|
||||
client *minio.Client
|
||||
bucket string
|
||||
presignExpiry time.Duration
|
||||
enabled bool
|
||||
}
|
||||
|
||||
func NewMinIOClient(cfg config.MinIOConfig) (*MinIOClient, error) {
|
||||
if cfg.Endpoint == "" || cfg.AccessKey == "" || cfg.SecretKey == "" {
|
||||
return &MinIOClient{
|
||||
bucket: cfg.Bucket,
|
||||
presignExpiry: cfg.PresignExpiry,
|
||||
enabled: false,
|
||||
}, nil
|
||||
}
|
||||
|
||||
client, err := minio.New(cfg.Endpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
|
||||
Secure: cfg.UseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &MinIOClient{
|
||||
client: client,
|
||||
bucket: cfg.Bucket,
|
||||
presignExpiry: cfg.PresignExpiry,
|
||||
enabled: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *MinIOClient) Enabled() bool {
|
||||
return c != nil && c.enabled
|
||||
}
|
||||
|
||||
func (c *MinIOClient) EnsureBucket(ctx context.Context) error {
|
||||
if !c.Enabled() {
|
||||
return nil
|
||||
}
|
||||
|
||||
exists, err := c.client.BucketExists(ctx, c.bucket)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
return c.client.MakeBucket(ctx, c.bucket, minio.MakeBucketOptions{})
|
||||
}
|
||||
|
||||
func (c *MinIOClient) PresignUpload(ctx context.Context, objectKey string) (*url.URL, error) {
|
||||
if !c.Enabled() {
|
||||
return nil, errors.New("minio is disabled")
|
||||
}
|
||||
return c.client.PresignedPutObject(ctx, c.bucket, objectKey, c.presignExpiry)
|
||||
}
|
||||
|
||||
func (c *MinIOClient) PresignDownload(ctx context.Context, objectKey string) (*url.URL, error) {
|
||||
return c.PresignDownloadWithFilename(ctx, objectKey, "")
|
||||
}
|
||||
|
||||
func (c *MinIOClient) PresignDownloadWithFilename(ctx context.Context, objectKey, filename string) (*url.URL, error) {
|
||||
if !c.Enabled() {
|
||||
return nil, errors.New("minio is disabled")
|
||||
}
|
||||
|
||||
var reqParams url.Values
|
||||
if filename = sanitizeDownloadFilename(filename); filename != "" {
|
||||
reqParams = make(url.Values)
|
||||
reqParams.Set("response-content-disposition", `attachment; filename="`+filename+`"`)
|
||||
}
|
||||
|
||||
return c.client.PresignedGetObject(ctx, c.bucket, objectKey, c.presignExpiry, reqParams)
|
||||
}
|
||||
|
||||
func (c *MinIOClient) UploadObject(ctx context.Context, objectKey string, reader io.Reader, size int64, contentType string) error {
|
||||
if !c.Enabled() {
|
||||
return errors.New("minio is disabled")
|
||||
}
|
||||
|
||||
options := minio.PutObjectOptions{
|
||||
ContentType: contentType,
|
||||
}
|
||||
_, err := c.client.PutObject(ctx, c.bucket, objectKey, reader, size, options)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *MinIOClient) OpenObject(ctx context.Context, objectKey string) (*minio.Object, minio.ObjectInfo, error) {
|
||||
if !c.Enabled() {
|
||||
return nil, minio.ObjectInfo{}, errors.New("minio is disabled")
|
||||
}
|
||||
|
||||
object, err := c.client.GetObject(ctx, c.bucket, objectKey, minio.GetObjectOptions{})
|
||||
if err != nil {
|
||||
return nil, minio.ObjectInfo{}, err
|
||||
}
|
||||
|
||||
info, err := object.Stat()
|
||||
if err != nil {
|
||||
_ = object.Close()
|
||||
return nil, minio.ObjectInfo{}, err
|
||||
}
|
||||
|
||||
return object, info, nil
|
||||
}
|
||||
|
||||
func (c *MinIOClient) RemoveObject(ctx context.Context, objectKey string) error {
|
||||
if !c.Enabled() {
|
||||
return nil
|
||||
}
|
||||
return c.client.RemoveObject(ctx, c.bucket, objectKey, minio.RemoveObjectOptions{})
|
||||
}
|
||||
|
||||
func sanitizeDownloadFilename(filename string) string {
|
||||
filename = strings.TrimSpace(filename)
|
||||
if filename == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
filename = strings.ReplaceAll(filename, `"`, "")
|
||||
filename = strings.ReplaceAll(filename, "\r", "")
|
||||
filename = strings.ReplaceAll(filename, "\n", "")
|
||||
return filename
|
||||
}
|
||||
205
backend/internal/storage/redis.go
Normal file
205
backend/internal/storage/redis.go
Normal file
@@ -0,0 +1,205 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/config"
|
||||
"filefast/backend/internal/model"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
RedisRealtimeRelayChannel = "filefast:ws:relay"
|
||||
RedisRealtimePresenceChannel = "filefast:ws:presence"
|
||||
)
|
||||
|
||||
type RedisClient struct {
|
||||
client *redis.Client
|
||||
}
|
||||
|
||||
type RealtimeMessage struct {
|
||||
Channel string
|
||||
Payload []byte
|
||||
}
|
||||
|
||||
func NewRedisClient(cfg config.RedisConfig) *RedisClient {
|
||||
client := redis.NewClient(&redis.Options{
|
||||
Addr: cfg.Addr,
|
||||
Password: cfg.Password,
|
||||
DB: cfg.DB,
|
||||
MaxRetries: 1,
|
||||
DialTimeout: 3 * time.Second,
|
||||
ReadTimeout: 3 * time.Second,
|
||||
WriteTimeout: 3 * time.Second,
|
||||
})
|
||||
|
||||
return &RedisClient{client: client}
|
||||
}
|
||||
|
||||
func (c *RedisClient) Ping(ctx context.Context) error {
|
||||
if !c.available() {
|
||||
return redis.ErrClosed
|
||||
}
|
||||
return c.client.Ping(ctx).Err()
|
||||
}
|
||||
|
||||
func (c *RedisClient) SaveAdminSession(ctx context.Context, session model.AdminSession, ttl time.Duration) error {
|
||||
if !c.available() {
|
||||
return redis.ErrClosed
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.client.Set(ctx, adminSessionKey(session.Token), payload, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *RedisClient) HasAdminSession(ctx context.Context, token string) (bool, error) {
|
||||
if !c.available() {
|
||||
return false, redis.ErrClosed
|
||||
}
|
||||
|
||||
result, err := c.client.Exists(ctx, adminSessionKey(token)).Result()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return result > 0, nil
|
||||
}
|
||||
|
||||
func (c *RedisClient) SaveDeviceSession(ctx context.Context, session model.DeviceSession, ttl time.Duration) error {
|
||||
if !c.available() {
|
||||
return redis.ErrClosed
|
||||
}
|
||||
|
||||
return c.client.Set(ctx, deviceSessionKey(session.DeviceID), session.Token, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *RedisClient) ValidateDeviceSession(ctx context.Context, deviceID, token string) (bool, error) {
|
||||
if !c.available() {
|
||||
return false, redis.ErrClosed
|
||||
}
|
||||
|
||||
value, err := c.client.Get(ctx, deviceSessionKey(deviceID)).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return value == token, nil
|
||||
}
|
||||
|
||||
func (c *RedisClient) SetDevicePresence(ctx context.Context, deviceID string, online bool, lastSeen time.Time, ttl time.Duration) error {
|
||||
if !c.available() {
|
||||
return redis.ErrClosed
|
||||
}
|
||||
|
||||
key := devicePresenceKey(deviceID)
|
||||
if !online {
|
||||
return c.client.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
return c.client.Set(ctx, key, lastSeen.UTC().Format(time.RFC3339Nano), ttl).Err()
|
||||
}
|
||||
|
||||
func (c *RedisClient) GetDevicePresence(ctx context.Context, deviceIDs []string) (map[string]bool, error) {
|
||||
if !c.available() {
|
||||
return nil, redis.ErrClosed
|
||||
}
|
||||
|
||||
statuses := make(map[string]bool, len(deviceIDs))
|
||||
if len(deviceIDs) == 0 {
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(deviceIDs))
|
||||
for _, id := range deviceIDs {
|
||||
keys = append(keys, devicePresenceKey(id))
|
||||
}
|
||||
|
||||
values, err := c.client.MGet(ctx, keys...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for index, id := range deviceIDs {
|
||||
statuses[id] = values[index] != nil
|
||||
}
|
||||
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
func (c *RedisClient) PublishRealtime(ctx context.Context, channel string, payload []byte) error {
|
||||
if !c.available() {
|
||||
return redis.ErrClosed
|
||||
}
|
||||
return c.client.Publish(ctx, channel, payload).Err()
|
||||
}
|
||||
|
||||
func (c *RedisClient) SubscribeRealtime(ctx context.Context, channels ...string) (<-chan RealtimeMessage, error) {
|
||||
if !c.available() {
|
||||
return nil, redis.ErrClosed
|
||||
}
|
||||
|
||||
pubsub := c.client.Subscribe(ctx, channels...)
|
||||
if _, err := pubsub.Receive(ctx); err != nil {
|
||||
_ = pubsub.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
source := pubsub.Channel()
|
||||
messages := make(chan RealtimeMessage, 64)
|
||||
|
||||
go func() {
|
||||
defer close(messages)
|
||||
defer pubsub.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case message, ok := <-source:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
messages <- RealtimeMessage{
|
||||
Channel: message.Channel,
|
||||
Payload: []byte(message.Payload),
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return messages, nil
|
||||
}
|
||||
|
||||
func (c *RedisClient) Close() error {
|
||||
if !c.available() {
|
||||
return nil
|
||||
}
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
func (c *RedisClient) available() bool {
|
||||
return c != nil && c.client != nil
|
||||
}
|
||||
|
||||
func adminSessionKey(token string) string {
|
||||
return fmt.Sprintf("filefast:admin:session:%s", token)
|
||||
}
|
||||
|
||||
func devicePresenceKey(deviceID string) string {
|
||||
return fmt.Sprintf("filefast:device:online:%s", deviceID)
|
||||
}
|
||||
|
||||
func deviceSessionKey(deviceID string) string {
|
||||
return fmt.Sprintf("filefast:device:session:%s", deviceID)
|
||||
}
|
||||
125
backend/internal/storage/shared.go
Normal file
125
backend/internal/storage/shared.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"filefast/backend/internal/model"
|
||||
)
|
||||
|
||||
const (
|
||||
runtimeConfigKey = "transfer_policy"
|
||||
)
|
||||
|
||||
type runtimeConfigPayload struct {
|
||||
MaxMinIOFallbackSizeBytes *int64 `json:"max_minio_fallback_size_bytes"`
|
||||
MaxMinIOFallbackGB *int64 `json:"max_minio_fallback_gb"`
|
||||
MinIOCapacityBytes *int64 `json:"minio_capacity_bytes"`
|
||||
MinIOCapacityGB *int64 `json:"minio_capacity_gb"`
|
||||
MinIORetentionHours *int `json:"minio_retention_hours"`
|
||||
MinIOUsageAlertPercent *int `json:"minio_usage_alert_percent"`
|
||||
P2PConnectTimeoutSec *int `json:"p2p_connect_timeout_sec"`
|
||||
TURNConnectTimeoutSec *int `json:"turn_connect_timeout_sec"`
|
||||
MinIOFallbackEnabled *bool `json:"minio_fallback_enabled"`
|
||||
TURNURLs []string `json:"turn_urls"`
|
||||
TURNUsername *string `json:"turn_username"`
|
||||
TURNPassword *string `json:"turn_password"`
|
||||
}
|
||||
|
||||
func decodeRuntimeConfig(raw []byte, fallback model.RuntimeConfig) (model.RuntimeConfig, error) {
|
||||
cfg := fallback
|
||||
if len(raw) == 0 {
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
var payload runtimeConfigPayload
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
return model.RuntimeConfig{}, err
|
||||
}
|
||||
|
||||
if payload.MaxMinIOFallbackSizeBytes != nil {
|
||||
cfg.MaxMinIOFallbackSizeBytes = *payload.MaxMinIOFallbackSizeBytes
|
||||
} else if payload.MaxMinIOFallbackGB != nil {
|
||||
cfg.MaxMinIOFallbackSizeBytes = *payload.MaxMinIOFallbackGB * 1024 * 1024 * 1024
|
||||
}
|
||||
if payload.MinIOCapacityBytes != nil {
|
||||
cfg.MinIOCapacityBytes = *payload.MinIOCapacityBytes
|
||||
} else if payload.MinIOCapacityGB != nil {
|
||||
cfg.MinIOCapacityBytes = *payload.MinIOCapacityGB * 1024 * 1024 * 1024
|
||||
}
|
||||
if payload.MinIORetentionHours != nil {
|
||||
cfg.MinIORetentionHours = *payload.MinIORetentionHours
|
||||
}
|
||||
if payload.MinIOUsageAlertPercent != nil {
|
||||
cfg.MinIOUsageAlertPercent = *payload.MinIOUsageAlertPercent
|
||||
}
|
||||
if payload.P2PConnectTimeoutSec != nil {
|
||||
cfg.P2PConnectTimeoutSec = *payload.P2PConnectTimeoutSec
|
||||
}
|
||||
if payload.TURNConnectTimeoutSec != nil {
|
||||
cfg.TURNConnectTimeoutSec = *payload.TURNConnectTimeoutSec
|
||||
}
|
||||
if payload.MinIOFallbackEnabled != nil {
|
||||
cfg.MinIOFallbackEnabled = *payload.MinIOFallbackEnabled
|
||||
}
|
||||
if payload.TURNURLs != nil {
|
||||
cfg.TURNURLs = payload.TURNURLs
|
||||
}
|
||||
if payload.TURNUsername != nil {
|
||||
cfg.TURNUsername = *payload.TURNUsername
|
||||
}
|
||||
if payload.TURNPassword != nil {
|
||||
cfg.TURNPassword = *payload.TURNPassword
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func readFirstExistingFile(paths ...string) ([]byte, error) {
|
||||
candidates := make([]string, 0, len(paths)*8)
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
cwd, err := os.Getwd()
|
||||
if err == nil {
|
||||
current := cwd
|
||||
for {
|
||||
for _, path := range paths {
|
||||
candidate := filepath.Clean(filepath.Join(current, path))
|
||||
if _, ok := seen[candidate]; ok {
|
||||
continue
|
||||
}
|
||||
seen[candidate] = struct{}{}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
|
||||
parent := filepath.Dir(current)
|
||||
if parent == current {
|
||||
break
|
||||
}
|
||||
current = parent
|
||||
}
|
||||
}
|
||||
|
||||
for _, path := range paths {
|
||||
candidate := filepath.Clean(path)
|
||||
if _, ok := seen[candidate]; ok {
|
||||
continue
|
||||
}
|
||||
seen[candidate] = struct{}{}
|
||||
candidates = append(candidates, candidate)
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
data, err := os.ReadFile(candidate)
|
||||
if err == nil {
|
||||
return data, nil
|
||||
}
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, os.ErrNotExist
|
||||
}
|
||||
617
backend/internal/storage/sqlite.go
Normal file
617
backend/internal/storage/sqlite.go
Normal file
@@ -0,0 +1,617 @@
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/config"
|
||||
"filefast/backend/internal/model"
|
||||
"filefast/backend/internal/store"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
type SQLiteClient struct {
|
||||
db *sql.DB
|
||||
path string
|
||||
}
|
||||
|
||||
func NewSQLiteClient(cfg config.SQLiteConfig) (*SQLiteClient, error) {
|
||||
path := strings.TrimSpace(cfg.Path)
|
||||
if path == "" {
|
||||
return nil, errors.New("sqlite path is required")
|
||||
}
|
||||
|
||||
if dir := filepath.Dir(path); dir != "" && dir != "." {
|
||||
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite", path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
db.SetMaxOpenConns(1)
|
||||
db.SetMaxIdleConns(1)
|
||||
db.SetConnMaxLifetime(0)
|
||||
|
||||
client := &SQLiteClient{
|
||||
db: db,
|
||||
path: path,
|
||||
}
|
||||
|
||||
for _, pragma := range []string{
|
||||
"PRAGMA journal_mode = WAL",
|
||||
"PRAGMA foreign_keys = ON",
|
||||
"PRAGMA busy_timeout = 5000",
|
||||
} {
|
||||
if _, err := db.Exec(pragma); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) Ping(ctx context.Context) error {
|
||||
return c.db.PingContext(ctx)
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) EnsureSchema(ctx context.Context) error {
|
||||
sqlBytes, err := readFirstExistingFile(
|
||||
filepath.Join("sql", "init_sqlite.sql"),
|
||||
filepath.Join("backend", "sql", "init_sqlite.sql"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = c.db.ExecContext(ctx, string(sqlBytes))
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) ResetOnlineDevices(ctx context.Context) error {
|
||||
_, err := c.db.ExecContext(ctx, `UPDATE devices SET is_online = 0 WHERE is_online = 1`)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) ResetOperationalData(ctx context.Context) error {
|
||||
tx, err := c.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
for _, statement := range []string{
|
||||
`DELETE FROM fallback_objects`,
|
||||
`DELETE FROM transfers`,
|
||||
`DELETE FROM sessions`,
|
||||
`DELETE FROM rooms`,
|
||||
`DELETE FROM devices`,
|
||||
} {
|
||||
if _, err := tx.ExecContext(ctx, statement); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) LoadSnapshot(ctx context.Context, fallbackRuntime model.RuntimeConfig) (store.Snapshot, error) {
|
||||
snapshot := store.Snapshot{
|
||||
Runtime: &fallbackRuntime,
|
||||
}
|
||||
|
||||
devices, err := c.loadDevices(ctx)
|
||||
if err != nil {
|
||||
return store.Snapshot{}, err
|
||||
}
|
||||
snapshot.Devices = devices
|
||||
|
||||
rooms, err := c.loadRooms(ctx)
|
||||
if err != nil {
|
||||
return store.Snapshot{}, err
|
||||
}
|
||||
snapshot.Rooms = rooms
|
||||
|
||||
transfers, err := c.loadTransfers(ctx)
|
||||
if err != nil {
|
||||
return store.Snapshot{}, err
|
||||
}
|
||||
snapshot.Transfers = transfers
|
||||
|
||||
fallbackObjects, err := c.loadFallbackObjects(ctx)
|
||||
if err != nil {
|
||||
return store.Snapshot{}, err
|
||||
}
|
||||
snapshot.FallbackObjects = fallbackObjects
|
||||
|
||||
runtimeCfg, err := c.loadRuntimeConfig(ctx, fallbackRuntime)
|
||||
if err != nil {
|
||||
return store.Snapshot{}, err
|
||||
}
|
||||
snapshot.Runtime = &runtimeCfg
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) PersistDevice(ctx context.Context, device model.Device) error {
|
||||
_, err := c.db.ExecContext(ctx, `
|
||||
INSERT INTO devices (
|
||||
id, device_code, name, type, user_agent, network_group_key, public_ip_hash,
|
||||
is_online, last_seen_at, created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
device_code = excluded.device_code,
|
||||
name = excluded.name,
|
||||
type = excluded.type,
|
||||
user_agent = excluded.user_agent,
|
||||
network_group_key = excluded.network_group_key,
|
||||
public_ip_hash = excluded.public_ip_hash,
|
||||
is_online = excluded.is_online,
|
||||
last_seen_at = excluded.last_seen_at,
|
||||
updated_at = excluded.updated_at
|
||||
`,
|
||||
device.ID,
|
||||
device.ID,
|
||||
device.Name,
|
||||
device.Type,
|
||||
nullableString(device.UserAgent),
|
||||
nullableString(device.NetworkGroupKey),
|
||||
nullableString(device.PublicIPHash),
|
||||
boolToInt(device.IsOnline),
|
||||
encodeTime(device.LastSeenAt),
|
||||
encodeTime(device.CreatedAt),
|
||||
encodeTime(time.Now()),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) PersistRoom(ctx context.Context, room model.Room) error {
|
||||
_, err := c.db.ExecContext(ctx, `
|
||||
INSERT INTO rooms (
|
||||
code, creator_device_id, joiner_device_id, status, created_at, expires_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(code) DO UPDATE SET
|
||||
creator_device_id = excluded.creator_device_id,
|
||||
joiner_device_id = excluded.joiner_device_id,
|
||||
status = excluded.status,
|
||||
expires_at = excluded.expires_at,
|
||||
updated_at = excluded.updated_at
|
||||
`,
|
||||
room.Code,
|
||||
room.CreatorDeviceID,
|
||||
nullableString(room.JoinerDeviceID),
|
||||
room.Status,
|
||||
encodeTime(room.CreatedAt),
|
||||
encodeTime(room.ExpiresAt),
|
||||
encodeTime(time.Now()),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) PersistTransfer(ctx context.Context, transfer model.Transfer) error {
|
||||
_, err := c.db.ExecContext(ctx, `
|
||||
INSERT INTO transfers (
|
||||
id, session_id, kind, name, content, size_bytes,
|
||||
sender_device_id, receiver_device_id, transfer_strategy, current_channel,
|
||||
fallback_allowed, final_status, fallback_reason, object_key, expires_at,
|
||||
created_at, updated_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
session_id = excluded.session_id,
|
||||
kind = excluded.kind,
|
||||
name = excluded.name,
|
||||
content = excluded.content,
|
||||
size_bytes = excluded.size_bytes,
|
||||
sender_device_id = excluded.sender_device_id,
|
||||
receiver_device_id = excluded.receiver_device_id,
|
||||
transfer_strategy = excluded.transfer_strategy,
|
||||
current_channel = excluded.current_channel,
|
||||
fallback_allowed = excluded.fallback_allowed,
|
||||
final_status = excluded.final_status,
|
||||
fallback_reason = excluded.fallback_reason,
|
||||
object_key = excluded.object_key,
|
||||
expires_at = excluded.expires_at,
|
||||
updated_at = excluded.updated_at
|
||||
`,
|
||||
transfer.ID,
|
||||
nullableString(transfer.SessionID),
|
||||
transfer.Kind,
|
||||
transfer.Name,
|
||||
nullableString(transfer.Content),
|
||||
transfer.SizeBytes,
|
||||
transfer.SenderDeviceID,
|
||||
transfer.ReceiverDeviceID,
|
||||
transfer.TransferStrategy,
|
||||
transfer.CurrentChannel,
|
||||
boolToInt(transfer.FallbackAllowed),
|
||||
transfer.FinalStatus,
|
||||
nullableString(transfer.FallbackReason),
|
||||
nullableString(transfer.ObjectKey),
|
||||
nullableTime(transfer.ExpiresAt),
|
||||
encodeTime(transfer.CreatedAt),
|
||||
encodeTime(transfer.UpdatedAt),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) PersistFallbackObject(ctx context.Context, object model.FallbackObject) error {
|
||||
_, err := c.db.ExecContext(ctx, `
|
||||
INSERT INTO fallback_objects (
|
||||
transfer_id, bucket, object_key, size_bytes, cleanup_state, created_at, expires_at, cleaned_at
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
ON CONFLICT(transfer_id) DO UPDATE SET
|
||||
bucket = excluded.bucket,
|
||||
object_key = excluded.object_key,
|
||||
size_bytes = excluded.size_bytes,
|
||||
cleanup_state = excluded.cleanup_state,
|
||||
expires_at = excluded.expires_at,
|
||||
cleaned_at = excluded.cleaned_at
|
||||
`,
|
||||
object.TransferID,
|
||||
"filefast-fallback",
|
||||
object.ObjectKey,
|
||||
object.SizeBytes,
|
||||
object.CleanupState,
|
||||
encodeTime(object.CreatedAt),
|
||||
encodeTime(object.ExpiresAt),
|
||||
nullableTime(object.CleanedAt),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) PersistRuntimeConfig(ctx context.Context, runtime model.RuntimeConfig) error {
|
||||
payload, err := json.Marshal(runtime)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = c.db.ExecContext(ctx, `
|
||||
INSERT INTO system_configs (config_key, config_value, updated_at)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(config_key) DO UPDATE SET
|
||||
config_value = excluded.config_value,
|
||||
updated_at = excluded.updated_at
|
||||
`, runtimeConfigKey, string(payload), encodeTime(time.Now()))
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) EnsureAdminUser(ctx context.Context, username, password string) error {
|
||||
username = strings.TrimSpace(username)
|
||||
password = strings.TrimSpace(password)
|
||||
if username == "" || password == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = c.db.ExecContext(ctx, `
|
||||
INSERT INTO admin_users (username, password_hash, is_active, created_at, updated_at)
|
||||
VALUES (?, ?, 1, ?, ?)
|
||||
ON CONFLICT(username) DO NOTHING
|
||||
`, username, string(hash), encodeTime(time.Now()), encodeTime(time.Now()))
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) ValidateAdminCredentials(ctx context.Context, username, password string) (bool, error) {
|
||||
var passwordHash string
|
||||
var isActive int
|
||||
err := c.db.QueryRowContext(ctx, `
|
||||
SELECT password_hash, is_active
|
||||
FROM admin_users
|
||||
WHERE username = ?
|
||||
`, strings.TrimSpace(username)).Scan(&passwordHash, &isActive)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
if isActive == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(strings.TrimSpace(password))) == nil, nil
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) loadDevices(ctx context.Context) ([]model.Device, error) {
|
||||
rows, err := c.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
type,
|
||||
COALESCE(user_agent, ''),
|
||||
COALESCE(network_group_key, ''),
|
||||
COALESCE(public_ip_hash, ''),
|
||||
is_online,
|
||||
COALESCE(last_seen_at, created_at),
|
||||
created_at
|
||||
FROM devices
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var devices []model.Device
|
||||
for rows.Next() {
|
||||
var (
|
||||
device model.Device
|
||||
isOnline int
|
||||
lastSeen string
|
||||
createdAt string
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&device.ID,
|
||||
&device.Name,
|
||||
&device.Type,
|
||||
&device.UserAgent,
|
||||
&device.NetworkGroupKey,
|
||||
&device.PublicIPHash,
|
||||
&isOnline,
|
||||
&lastSeen,
|
||||
&createdAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
device.IsOnline = isOnline != 0
|
||||
device.LastSeenAt, err = decodeTime(lastSeen)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
device.CreatedAt, err = decodeTime(createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
devices = append(devices, device)
|
||||
}
|
||||
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) loadRooms(ctx context.Context) ([]model.Room, error) {
|
||||
rows, err := c.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
code,
|
||||
creator_device_id,
|
||||
COALESCE(joiner_device_id, ''),
|
||||
status,
|
||||
created_at,
|
||||
expires_at
|
||||
FROM rooms
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var rooms []model.Room
|
||||
for rows.Next() {
|
||||
var (
|
||||
room model.Room
|
||||
createdAt string
|
||||
expiresAt string
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&room.Code,
|
||||
&room.CreatorDeviceID,
|
||||
&room.JoinerDeviceID,
|
||||
&room.Status,
|
||||
&createdAt,
|
||||
&expiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
room.CreatedAt, err = decodeTime(createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
room.ExpiresAt, err = decodeTime(expiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rooms = append(rooms, room)
|
||||
}
|
||||
|
||||
return rooms, rows.Err()
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) loadTransfers(ctx context.Context) ([]model.Transfer, error) {
|
||||
rows, err := c.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(session_id, ''),
|
||||
kind,
|
||||
name,
|
||||
COALESCE(content, ''),
|
||||
size_bytes,
|
||||
sender_device_id,
|
||||
receiver_device_id,
|
||||
transfer_strategy,
|
||||
current_channel,
|
||||
fallback_allowed,
|
||||
final_status,
|
||||
created_at,
|
||||
updated_at,
|
||||
COALESCE(fallback_reason, ''),
|
||||
COALESCE(object_key, ''),
|
||||
expires_at
|
||||
FROM transfers
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var transfers []model.Transfer
|
||||
for rows.Next() {
|
||||
var (
|
||||
transfer model.Transfer
|
||||
createdAt string
|
||||
updatedAt string
|
||||
expiresAt sql.NullString
|
||||
canFallback int
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&transfer.ID,
|
||||
&transfer.SessionID,
|
||||
&transfer.Kind,
|
||||
&transfer.Name,
|
||||
&transfer.Content,
|
||||
&transfer.SizeBytes,
|
||||
&transfer.SenderDeviceID,
|
||||
&transfer.ReceiverDeviceID,
|
||||
&transfer.TransferStrategy,
|
||||
&transfer.CurrentChannel,
|
||||
&canFallback,
|
||||
&transfer.FinalStatus,
|
||||
&createdAt,
|
||||
&updatedAt,
|
||||
&transfer.FallbackReason,
|
||||
&transfer.ObjectKey,
|
||||
&expiresAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transfer.FallbackAllowed = canFallback != 0
|
||||
transfer.CreatedAt, err = decodeTime(createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transfer.UpdatedAt, err = decodeTime(updatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if expiresAt.Valid && strings.TrimSpace(expiresAt.String) != "" {
|
||||
parsed, err := decodeTime(expiresAt.String)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
transfer.ExpiresAt = &parsed
|
||||
}
|
||||
transfers = append(transfers, transfer)
|
||||
}
|
||||
|
||||
return transfers, rows.Err()
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) loadFallbackObjects(ctx context.Context) ([]model.FallbackObject, error) {
|
||||
rows, err := c.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
transfer_id,
|
||||
object_key,
|
||||
size_bytes,
|
||||
created_at,
|
||||
expires_at,
|
||||
cleaned_at,
|
||||
cleanup_state
|
||||
FROM fallback_objects
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var objects []model.FallbackObject
|
||||
for rows.Next() {
|
||||
var (
|
||||
object model.FallbackObject
|
||||
createdAt string
|
||||
expiresAt string
|
||||
cleanedAt sql.NullString
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&object.TransferID,
|
||||
&object.ObjectKey,
|
||||
&object.SizeBytes,
|
||||
&createdAt,
|
||||
&expiresAt,
|
||||
&cleanedAt,
|
||||
&object.CleanupState,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
object.CreatedAt, err = decodeTime(createdAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
object.ExpiresAt, err = decodeTime(expiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cleanedAt.Valid && strings.TrimSpace(cleanedAt.String) != "" {
|
||||
parsed, err := decodeTime(cleanedAt.String)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
object.CleanedAt = &parsed
|
||||
}
|
||||
objects = append(objects, object)
|
||||
}
|
||||
|
||||
return objects, rows.Err()
|
||||
}
|
||||
|
||||
func (c *SQLiteClient) loadRuntimeConfig(ctx context.Context, fallback model.RuntimeConfig) (model.RuntimeConfig, error) {
|
||||
var raw string
|
||||
err := c.db.QueryRowContext(ctx, `
|
||||
SELECT config_value
|
||||
FROM system_configs
|
||||
WHERE config_key = ?
|
||||
`, runtimeConfigKey).Scan(&raw)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return fallback, nil
|
||||
}
|
||||
return model.RuntimeConfig{}, err
|
||||
}
|
||||
|
||||
return decodeRuntimeConfig([]byte(raw), fallback)
|
||||
}
|
||||
|
||||
func encodeTime(value time.Time) string {
|
||||
return value.UTC().Format(time.RFC3339Nano)
|
||||
}
|
||||
|
||||
func nullableTime(value *time.Time) any {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
return encodeTime(*value)
|
||||
}
|
||||
|
||||
func decodeTime(value string) (time.Time, error) {
|
||||
return time.Parse(time.RFC3339Nano, value)
|
||||
}
|
||||
|
||||
func nullableString(value string) any {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return nil
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func boolToInt(value bool) int {
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
478
backend/internal/store/memory.go
Normal file
478
backend/internal/store/memory.go
Normal file
@@ -0,0 +1,478 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/model"
|
||||
)
|
||||
|
||||
type Persistence interface {
|
||||
PersistDevice(context.Context, model.Device) error
|
||||
PersistRoom(context.Context, model.Room) error
|
||||
PersistTransfer(context.Context, model.Transfer) error
|
||||
PersistFallbackObject(context.Context, model.FallbackObject) error
|
||||
PersistRuntimeConfig(context.Context, model.RuntimeConfig) error
|
||||
}
|
||||
|
||||
type Snapshot struct {
|
||||
Devices []model.Device
|
||||
Rooms []model.Room
|
||||
Transfers []model.Transfer
|
||||
FallbackObjects []model.FallbackObject
|
||||
Runtime *model.RuntimeConfig
|
||||
}
|
||||
|
||||
type MemoryStore struct {
|
||||
mu sync.RWMutex
|
||||
devices map[string]model.Device
|
||||
rooms map[string]model.Room
|
||||
transfers map[string]model.Transfer
|
||||
fallbackObjects map[string]model.FallbackObject
|
||||
adminSessions map[string]model.AdminSession
|
||||
deviceSessions map[string]model.DeviceSession
|
||||
runtime model.RuntimeConfig
|
||||
persistence Persistence
|
||||
persistTimeout time.Duration
|
||||
onPersistError func(kind, id string, err error)
|
||||
}
|
||||
|
||||
const (
|
||||
activeDeviceWindow = 2 * time.Minute
|
||||
activeTransferWindow = 30 * time.Minute
|
||||
recentTerminalTransferWind = 24 * time.Hour
|
||||
)
|
||||
|
||||
func NewMemoryStore(runtime model.RuntimeConfig) *MemoryStore {
|
||||
return &MemoryStore{
|
||||
devices: make(map[string]model.Device),
|
||||
rooms: make(map[string]model.Room),
|
||||
transfers: make(map[string]model.Transfer),
|
||||
fallbackObjects: make(map[string]model.FallbackObject),
|
||||
adminSessions: make(map[string]model.AdminSession),
|
||||
deviceSessions: make(map[string]model.DeviceSession),
|
||||
runtime: runtime,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MemoryStore) UpsertDevice(device model.Device) model.Device {
|
||||
s.mu.Lock()
|
||||
s.devices[device.ID] = device
|
||||
s.mu.Unlock()
|
||||
s.persistDevice(device)
|
||||
return device
|
||||
}
|
||||
|
||||
func (s *MemoryStore) GetDevice(id string) (model.Device, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
device, ok := s.devices[id]
|
||||
return device, ok
|
||||
}
|
||||
|
||||
func (s *MemoryStore) ListDevices() []model.Device {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
devices := make([]model.Device, 0, len(s.devices))
|
||||
for _, device := range s.devices {
|
||||
devices = append(devices, device)
|
||||
}
|
||||
|
||||
return devices
|
||||
}
|
||||
|
||||
func (s *MemoryStore) UpsertRoom(room model.Room) model.Room {
|
||||
s.mu.Lock()
|
||||
s.rooms[room.Code] = room
|
||||
s.mu.Unlock()
|
||||
s.persistRoom(room)
|
||||
return room
|
||||
}
|
||||
|
||||
func (s *MemoryStore) GetRoom(code string) (model.Room, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
room, ok := s.rooms[code]
|
||||
return room, ok
|
||||
}
|
||||
|
||||
func (s *MemoryStore) UpsertTransfer(transfer model.Transfer) model.Transfer {
|
||||
s.mu.Lock()
|
||||
s.transfers[transfer.ID] = transfer
|
||||
s.mu.Unlock()
|
||||
s.persistTransfer(transfer)
|
||||
return transfer
|
||||
}
|
||||
|
||||
func (s *MemoryStore) GetTransfer(id string) (model.Transfer, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
transfer, ok := s.transfers[id]
|
||||
return transfer, ok
|
||||
}
|
||||
|
||||
func (s *MemoryStore) ListRecentTransfers(limit int) []model.Transfer {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
transfers := make([]model.Transfer, 0, len(s.transfers))
|
||||
for _, transfer := range s.transfers {
|
||||
if !isTransferVisible(transfer, now) {
|
||||
continue
|
||||
}
|
||||
transfers = append(transfers, transfer)
|
||||
}
|
||||
|
||||
sort.Slice(transfers, func(i, j int) bool {
|
||||
return transfers[i].UpdatedAt.After(transfers[j].UpdatedAt)
|
||||
})
|
||||
|
||||
if limit > 0 && len(transfers) > limit {
|
||||
return transfers[:limit]
|
||||
}
|
||||
|
||||
return transfers
|
||||
}
|
||||
|
||||
func (s *MemoryStore) ListPendingFallbackDownloads(receiverDeviceID string, limit int) []model.PendingFallbackDownload {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
if receiverDeviceID == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
downloads := make([]model.PendingFallbackDownload, 0, len(s.transfers))
|
||||
|
||||
for _, transfer := range s.transfers {
|
||||
if transfer.ReceiverDeviceID != receiverDeviceID || transfer.ObjectKey == "" || transfer.FinalStatus != model.TransferCompleted {
|
||||
continue
|
||||
}
|
||||
if transfer.ExpiresAt == nil || !transfer.ExpiresAt.After(now) {
|
||||
continue
|
||||
}
|
||||
|
||||
object, ok := s.fallbackObjects[transfer.ID]
|
||||
if !ok || object.CleanedAt != nil || object.ExpiresAt.Before(now) || object.CleanupState != "ready" {
|
||||
continue
|
||||
}
|
||||
|
||||
downloads = append(downloads, model.PendingFallbackDownload{
|
||||
TransferID: transfer.ID,
|
||||
Name: transfer.Name,
|
||||
SizeBytes: transfer.SizeBytes,
|
||||
CreatedAt: transfer.CreatedAt,
|
||||
ExpiresAt: object.ExpiresAt,
|
||||
DownloadPath: "/api/transfers/" + transfer.ID + "/fallback/download",
|
||||
SenderDeviceID: transfer.SenderDeviceID,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(downloads, func(i, j int) bool {
|
||||
return downloads[i].CreatedAt.After(downloads[j].CreatedAt)
|
||||
})
|
||||
|
||||
if limit > 0 && len(downloads) > limit {
|
||||
return downloads[:limit]
|
||||
}
|
||||
|
||||
return downloads
|
||||
}
|
||||
|
||||
func (s *MemoryStore) SaveFallbackObject(object model.FallbackObject) model.FallbackObject {
|
||||
s.mu.Lock()
|
||||
s.fallbackObjects[object.TransferID] = object
|
||||
s.mu.Unlock()
|
||||
s.persistFallbackObject(object)
|
||||
return object
|
||||
}
|
||||
|
||||
func (s *MemoryStore) GetFallbackObject(transferID string) (model.FallbackObject, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
object, ok := s.fallbackObjects[transferID]
|
||||
return object, ok
|
||||
}
|
||||
|
||||
func (s *MemoryStore) ListExpiredFallbackObjects(now time.Time) []model.FallbackObject {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
var objects []model.FallbackObject
|
||||
for _, object := range s.fallbackObjects {
|
||||
if object.CleanedAt == nil && !object.ExpiresAt.After(now) {
|
||||
objects = append(objects, object)
|
||||
}
|
||||
}
|
||||
|
||||
return objects
|
||||
}
|
||||
|
||||
func (s *MemoryStore) SaveAdminSession(session model.AdminSession) model.AdminSession {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.adminSessions[session.Token] = session
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *MemoryStore) SaveDeviceSession(session model.DeviceSession) model.DeviceSession {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.deviceSessions[session.DeviceID] = session
|
||||
return session
|
||||
}
|
||||
|
||||
func (s *MemoryStore) HasAdminSession(token string) bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
_, ok := s.adminSessions[token]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *MemoryStore) ValidateDeviceSession(deviceID, token string) bool {
|
||||
s.mu.RLock()
|
||||
session, ok := s.deviceSessions[deviceID]
|
||||
s.mu.RUnlock()
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if session.ExpiresAt.Before(time.Now()) {
|
||||
s.mu.Lock()
|
||||
delete(s.deviceSessions, deviceID)
|
||||
s.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
return session.Token == token
|
||||
}
|
||||
|
||||
func (s *MemoryStore) RuntimeConfig() model.RuntimeConfig {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.runtime
|
||||
}
|
||||
|
||||
func (s *MemoryStore) UpdateRuntimeConfig(runtime model.RuntimeConfig) model.RuntimeConfig {
|
||||
s.mu.Lock()
|
||||
s.runtime = runtime
|
||||
s.mu.Unlock()
|
||||
s.persistRuntimeConfig(runtime)
|
||||
return runtime
|
||||
}
|
||||
|
||||
func (s *MemoryStore) SnapshotStats() map[string]int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
onlineDevices := 0
|
||||
activeDevices := 0
|
||||
for _, device := range s.devices {
|
||||
if isDeviceActive(device, now) {
|
||||
activeDevices++
|
||||
}
|
||||
if device.IsOnline && isDeviceActive(device, now) {
|
||||
onlineDevices++
|
||||
}
|
||||
}
|
||||
|
||||
waitingRooms := 0
|
||||
for _, room := range s.rooms {
|
||||
if room.Status == model.RoomStatusWaiting && room.ExpiresAt.After(now) {
|
||||
waitingRooms++
|
||||
}
|
||||
}
|
||||
|
||||
fallbackPending := 0
|
||||
for _, object := range s.fallbackObjects {
|
||||
if object.CleanedAt == nil && object.CleanupState == "ready" {
|
||||
fallbackPending++
|
||||
}
|
||||
}
|
||||
|
||||
validTransfers := 0
|
||||
for _, transfer := range s.transfers {
|
||||
if isTransferVisible(transfer, now) {
|
||||
validTransfers++
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]int{
|
||||
"devices_total": activeDevices,
|
||||
"devices_online": onlineDevices,
|
||||
"rooms_waiting": waitingRooms,
|
||||
"transfers_total": validTransfers,
|
||||
"transfers_cumulative": len(s.transfers),
|
||||
"fallback_pending": fallbackPending,
|
||||
"admin_sessions": len(s.adminSessions),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MemoryStore) SnapshotMinIOStorage() model.MinIOStorageOverview {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
now := time.Now()
|
||||
usedBytes := int64(0)
|
||||
objectCount := 0
|
||||
|
||||
for _, object := range s.fallbackObjects {
|
||||
if object.CleanedAt != nil || !object.ExpiresAt.After(now) || object.CleanupState != "ready" {
|
||||
continue
|
||||
}
|
||||
|
||||
usedBytes += object.SizeBytes
|
||||
objectCount++
|
||||
}
|
||||
|
||||
capacityBytes := s.runtime.MinIOCapacityBytes
|
||||
if capacityBytes < 0 {
|
||||
capacityBytes = 0
|
||||
}
|
||||
|
||||
remainingBytes := capacityBytes - usedBytes
|
||||
if remainingBytes < 0 {
|
||||
remainingBytes = 0
|
||||
}
|
||||
|
||||
usagePercent := 0
|
||||
if capacityBytes > 0 {
|
||||
usagePercent = int((usedBytes * 100) / capacityBytes)
|
||||
if usagePercent > 100 {
|
||||
usagePercent = 100
|
||||
}
|
||||
}
|
||||
|
||||
return model.MinIOStorageOverview{
|
||||
Enabled: s.runtime.MinIOFallbackEnabled,
|
||||
UsedBytes: usedBytes,
|
||||
CapacityBytes: capacityBytes,
|
||||
RemainingBytes: remainingBytes,
|
||||
UsagePercent: usagePercent,
|
||||
ObjectCount: objectCount,
|
||||
}
|
||||
}
|
||||
|
||||
func isDeviceActive(device model.Device, now time.Time) bool {
|
||||
if device.LastSeenAt.IsZero() {
|
||||
return false
|
||||
}
|
||||
|
||||
return device.LastSeenAt.After(now.Add(-activeDeviceWindow))
|
||||
}
|
||||
|
||||
func isTransferVisible(transfer model.Transfer, now time.Time) bool {
|
||||
if transfer.UpdatedAt.IsZero() {
|
||||
transfer.UpdatedAt = transfer.CreatedAt
|
||||
}
|
||||
|
||||
if isTerminalTransferStatus(transfer.FinalStatus) {
|
||||
return transfer.UpdatedAt.After(now.Add(-recentTerminalTransferWind))
|
||||
}
|
||||
|
||||
return transfer.UpdatedAt.After(now.Add(-activeTransferWindow))
|
||||
}
|
||||
|
||||
func isTerminalTransferStatus(status string) bool {
|
||||
switch status {
|
||||
case model.TransferCompleted, model.TransferFailed, model.TransferCancelled:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MemoryStore) SetPersistence(persistence Persistence, timeout time.Duration, onError func(kind, id string, err error)) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.persistence = persistence
|
||||
s.persistTimeout = timeout
|
||||
s.onPersistError = onError
|
||||
}
|
||||
|
||||
func (s *MemoryStore) LoadSnapshot(snapshot Snapshot) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.devices = make(map[string]model.Device, len(snapshot.Devices))
|
||||
for _, device := range snapshot.Devices {
|
||||
s.devices[device.ID] = device
|
||||
}
|
||||
|
||||
s.rooms = make(map[string]model.Room, len(snapshot.Rooms))
|
||||
for _, room := range snapshot.Rooms {
|
||||
s.rooms[room.Code] = room
|
||||
}
|
||||
|
||||
s.transfers = make(map[string]model.Transfer, len(snapshot.Transfers))
|
||||
for _, transfer := range snapshot.Transfers {
|
||||
s.transfers[transfer.ID] = transfer
|
||||
}
|
||||
|
||||
s.fallbackObjects = make(map[string]model.FallbackObject, len(snapshot.FallbackObjects))
|
||||
for _, object := range snapshot.FallbackObjects {
|
||||
s.fallbackObjects[object.TransferID] = object
|
||||
}
|
||||
|
||||
if snapshot.Runtime != nil {
|
||||
s.runtime = *snapshot.Runtime
|
||||
}
|
||||
}
|
||||
|
||||
func (s *MemoryStore) persistDevice(device model.Device) {
|
||||
s.persist(device.ID, "device", func(ctx context.Context, persistence Persistence) error {
|
||||
return persistence.PersistDevice(ctx, device)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MemoryStore) persistRoom(room model.Room) {
|
||||
s.persist(room.Code, "room", func(ctx context.Context, persistence Persistence) error {
|
||||
return persistence.PersistRoom(ctx, room)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MemoryStore) persistTransfer(transfer model.Transfer) {
|
||||
s.persist(transfer.ID, "transfer", func(ctx context.Context, persistence Persistence) error {
|
||||
return persistence.PersistTransfer(ctx, transfer)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MemoryStore) persistFallbackObject(object model.FallbackObject) {
|
||||
s.persist(object.TransferID, "fallback_object", func(ctx context.Context, persistence Persistence) error {
|
||||
return persistence.PersistFallbackObject(ctx, object)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MemoryStore) persistRuntimeConfig(runtime model.RuntimeConfig) {
|
||||
s.persist("transfer_policy", "runtime_config", func(ctx context.Context, persistence Persistence) error {
|
||||
return persistence.PersistRuntimeConfig(ctx, runtime)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *MemoryStore) persist(id, kind string, fn func(context.Context, Persistence) error) {
|
||||
s.mu.RLock()
|
||||
persistence := s.persistence
|
||||
timeout := s.persistTimeout
|
||||
onError := s.onPersistError
|
||||
s.mu.RUnlock()
|
||||
|
||||
if persistence == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if timeout <= 0 {
|
||||
timeout = 5 * time.Second
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
if err := fn(ctx, persistence); err != nil && onError != nil {
|
||||
onError(kind, id, err)
|
||||
}
|
||||
}
|
||||
368
backend/internal/ws/hub.go
Normal file
368
backend/internal/ws/hub.go
Normal file
@@ -0,0 +1,368 @@
|
||||
package ws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"filefast/backend/internal/model"
|
||||
"filefast/backend/internal/service"
|
||||
"filefast/backend/internal/storage"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
writeWait = 10 * time.Second
|
||||
pongWait = 60 * time.Second
|
||||
pingPeriod = (pongWait * 9) / 10
|
||||
maxMessageSize = 8 * 1024 * 1024
|
||||
)
|
||||
|
||||
type Hub struct {
|
||||
logger *slog.Logger
|
||||
deviceService *service.DeviceService
|
||||
backplane realtimeBackplane
|
||||
instanceID string
|
||||
clients map[string]*Client
|
||||
register chan *Client
|
||||
unregister chan *Client
|
||||
relay chan relayMessage
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
hub *Hub
|
||||
deviceID string
|
||||
conn *websocket.Conn
|
||||
send chan []byte
|
||||
}
|
||||
|
||||
type relayMessage struct {
|
||||
targetDeviceID string
|
||||
payload []byte
|
||||
}
|
||||
|
||||
type inboundMessage struct {
|
||||
Type string `json:"type"`
|
||||
TargetDeviceID string `json:"target_device_id"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
|
||||
type backplaneEnvelope struct {
|
||||
Source string `json:"source"`
|
||||
TargetDeviceID string `json:"target_device_id,omitempty"`
|
||||
Payload json.RawMessage `json:"payload"`
|
||||
}
|
||||
|
||||
type realtimeBackplane interface {
|
||||
PublishRealtime(context.Context, string, []byte) error
|
||||
SubscribeRealtime(context.Context, ...string) (<-chan storage.RealtimeMessage, error)
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
origin := strings.TrimSpace(r.Header.Get("Origin"))
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(origin)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return originHostMatchesRequest(parsed.Host, r)
|
||||
},
|
||||
}
|
||||
|
||||
func NewHub(logger *slog.Logger, deviceService *service.DeviceService, backplane realtimeBackplane) *Hub {
|
||||
return &Hub{
|
||||
logger: logger,
|
||||
deviceService: deviceService,
|
||||
backplane: backplane,
|
||||
instanceID: fmt.Sprintf("%d", time.Now().UnixNano()),
|
||||
clients: make(map[string]*Client),
|
||||
register: make(chan *Client),
|
||||
unregister: make(chan *Client),
|
||||
relay: make(chan relayMessage),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) Run() {
|
||||
var backplaneMessages <-chan storage.RealtimeMessage
|
||||
if h.backplane != nil {
|
||||
messages, err := h.backplane.SubscribeRealtime(context.Background(), storage.RedisRealtimeRelayChannel, storage.RedisRealtimePresenceChannel)
|
||||
if err != nil {
|
||||
h.logger.Warn("failed to subscribe redis realtime backplane", "error", err)
|
||||
} else {
|
||||
backplaneMessages = messages
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case client := <-h.register:
|
||||
h.clients[client.deviceID] = client
|
||||
h.deviceService.SetOnline(client.deviceID, true)
|
||||
h.broadcastPresence(client.deviceID, true)
|
||||
case client := <-h.unregister:
|
||||
if existing, ok := h.clients[client.deviceID]; ok && existing == client {
|
||||
delete(h.clients, client.deviceID)
|
||||
close(client.send)
|
||||
h.deviceService.SetOnline(client.deviceID, false)
|
||||
h.broadcastPresence(client.deviceID, false)
|
||||
}
|
||||
case message := <-h.relay:
|
||||
h.dispatchRelay(message.targetDeviceID, message.payload)
|
||||
case message, ok := <-backplaneMessages:
|
||||
if !ok {
|
||||
backplaneMessages = nil
|
||||
continue
|
||||
}
|
||||
h.handleBackplaneMessage(message)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) Handle(c *gin.Context) {
|
||||
deviceID := c.Query("deviceId")
|
||||
token := strings.TrimSpace(c.Query("deviceToken"))
|
||||
if deviceID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "deviceId is required"})
|
||||
return
|
||||
}
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "deviceToken is required"})
|
||||
return
|
||||
}
|
||||
if !h.deviceService.ValidateSession(deviceID, token) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid device credentials"})
|
||||
return
|
||||
}
|
||||
|
||||
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to upgrade websocket"})
|
||||
return
|
||||
}
|
||||
|
||||
client := &Client{
|
||||
hub: h,
|
||||
deviceID: deviceID,
|
||||
conn: conn,
|
||||
send: make(chan []byte, 32),
|
||||
}
|
||||
|
||||
h.register <- client
|
||||
go client.writePump()
|
||||
client.readPump()
|
||||
}
|
||||
|
||||
func (h *Hub) broadcastPresence(deviceID string, online bool) {
|
||||
envelope := model.SignalEnvelope{
|
||||
Type: "presence.update",
|
||||
DeviceID: deviceID,
|
||||
Payload: map[string]any{
|
||||
"online": online,
|
||||
},
|
||||
}
|
||||
|
||||
data, err := json.Marshal(envelope)
|
||||
if err != nil {
|
||||
h.logger.Warn("failed to marshal presence update", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
h.broadcastPresencePayload(data)
|
||||
h.publishBackplane(storage.RedisRealtimePresenceChannel, "", data)
|
||||
}
|
||||
|
||||
func (h *Hub) broadcastPresencePayload(data []byte) {
|
||||
for _, client := range h.clients {
|
||||
client.send <- data
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) dispatchRelay(targetDeviceID string, payload []byte) {
|
||||
if client, ok := h.clients[targetDeviceID]; ok {
|
||||
client.send <- payload
|
||||
}
|
||||
|
||||
h.publishBackplane(storage.RedisRealtimeRelayChannel, targetDeviceID, payload)
|
||||
}
|
||||
|
||||
func (h *Hub) publishBackplane(channel, targetDeviceID string, payload []byte) {
|
||||
if h.backplane == nil {
|
||||
return
|
||||
}
|
||||
|
||||
envelope, err := json.Marshal(backplaneEnvelope{
|
||||
Source: h.instanceID,
|
||||
TargetDeviceID: targetDeviceID,
|
||||
Payload: payload,
|
||||
})
|
||||
if err != nil {
|
||||
h.logger.Warn("failed to marshal backplane envelope", "channel", channel, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := h.backplane.PublishRealtime(ctx, channel, envelope); err != nil {
|
||||
h.logger.Warn("failed to publish realtime backplane message", "channel", channel, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) handleBackplaneMessage(message storage.RealtimeMessage) {
|
||||
var envelope backplaneEnvelope
|
||||
if err := json.Unmarshal(message.Payload, &envelope); err != nil {
|
||||
h.logger.Warn("failed to decode realtime backplane message", "channel", message.Channel, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
if envelope.Source == h.instanceID {
|
||||
return
|
||||
}
|
||||
|
||||
switch message.Channel {
|
||||
case storage.RedisRealtimeRelayChannel:
|
||||
if client, ok := h.clients[envelope.TargetDeviceID]; ok {
|
||||
client.send <- envelope.Payload
|
||||
}
|
||||
case storage.RedisRealtimePresenceChannel:
|
||||
h.broadcastPresencePayload(envelope.Payload)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) readPump() {
|
||||
defer func() {
|
||||
c.hub.unregister <- c
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
c.conn.SetReadLimit(maxMessageSize)
|
||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
c.conn.SetPongHandler(func(string) error {
|
||||
c.conn.SetReadDeadline(time.Now().Add(pongWait))
|
||||
return nil
|
||||
})
|
||||
|
||||
for {
|
||||
var inbound inboundMessage
|
||||
if err := c.conn.ReadJSON(&inbound); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
envelope := model.SignalEnvelope{
|
||||
Type: inbound.Type,
|
||||
DeviceID: c.deviceID,
|
||||
TargetDeviceID: inbound.TargetDeviceID,
|
||||
Payload: inbound.Payload,
|
||||
}
|
||||
|
||||
data, err := json.Marshal(envelope)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
c.hub.relay <- relayMessage{
|
||||
targetDeviceID: inbound.TargetDeviceID,
|
||||
payload: data,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) writePump() {
|
||||
ticker := time.NewTicker(pingPeriod)
|
||||
defer func() {
|
||||
ticker.Stop()
|
||||
c.conn.Close()
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case message, ok := <-c.send:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if !ok {
|
||||
c.conn.WriteMessage(websocket.CloseMessage, []byte{})
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.conn.WriteMessage(websocket.TextMessage, message); err != nil {
|
||||
return
|
||||
}
|
||||
case <-ticker.C:
|
||||
c.conn.SetWriteDeadline(time.Now().Add(writeWait))
|
||||
if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func originHostMatchesRequest(originHost string, r *http.Request) bool {
|
||||
requestHosts := []string{
|
||||
strings.TrimSpace(r.Host),
|
||||
strings.TrimSpace(r.Header.Get("X-Forwarded-Host")),
|
||||
}
|
||||
|
||||
originName, err := normalizeHost(originHost)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, host := range requestHosts {
|
||||
if host == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
requestName, err := normalizeHost(host)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if requestName == originName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func normalizeHost(host string) (string, error) {
|
||||
host = strings.TrimSpace(host)
|
||||
if host == "" {
|
||||
return "", fmt.Errorf("empty host")
|
||||
}
|
||||
|
||||
if strings.Contains(host, "://") {
|
||||
parsed, err := url.Parse(host)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
host = parsed.Host
|
||||
}
|
||||
|
||||
name, _, err := net.SplitHostPort(host)
|
||||
if err == nil {
|
||||
host = name
|
||||
} else if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
|
||||
host = strings.TrimPrefix(strings.TrimSuffix(host, "]"), "[")
|
||||
}
|
||||
|
||||
host = strings.ToLower(strings.TrimSpace(host))
|
||||
switch host {
|
||||
case "127.0.0.1", "::1":
|
||||
return "localhost", nil
|
||||
default:
|
||||
return host, nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user