618 lines
14 KiB
Go
618 lines
14 KiB
Go
package storage
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"filefast/backend/internal/config"
|
|
"filefast/backend/internal/model"
|
|
"filefast/backend/internal/store"
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
type SQLiteClient struct {
|
|
db *sql.DB
|
|
path string
|
|
}
|
|
|
|
// NewSQLiteClient opens the SQLite database at cfg.Path and applies the
// pragmas the rest of this client relies on.
//
// The path is whitespace-trimmed and must not be empty. When the path has a
// parent directory other than ".", that directory is created (mode 0755)
// first. The pool is pinned to a single, never-expiring connection. The
// pragmas enable WAL journaling, foreign-key enforcement and a 5s busy
// timeout. If any pragma fails, the handle is closed and the error returned.
//
// NOTE(review): in SQLite, foreign_keys and busy_timeout are per-connection
// settings. If database/sql ever replaces the pooled connection (for example
// after driver.ErrBadConn), the replacement will not have them. Consider
// setting them via the DSN or a connection hook.
func NewSQLiteClient(cfg config.SQLiteConfig) (*SQLiteClient, error) {
	dbPath := strings.TrimSpace(cfg.Path)
	if dbPath == "" {
		return nil, errors.New("sqlite path is required")
	}

	switch dir := filepath.Dir(dbPath); dir {
	case "", ".":
		// The file lives in the working directory; nothing to create.
	default:
		if err := os.MkdirAll(dir, 0o755); err != nil {
			return nil, err
		}
	}

	db, err := sql.Open("sqlite", dbPath)
	if err != nil {
		return nil, err
	}

	// One long-lived connection: SQLite serialises writers anyway, and the
	// pragmas below are executed on this connection.
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	db.SetConnMaxLifetime(0)

	pragmas := [...]string{
		"PRAGMA journal_mode = WAL",
		"PRAGMA foreign_keys = ON",
		"PRAGMA busy_timeout = 5000",
	}
	for i := range pragmas {
		if _, execErr := db.Exec(pragmas[i]); execErr != nil {
			_ = db.Close()
			return nil, execErr
		}
	}

	return &SQLiteClient{db: db, path: dbPath}, nil
}
|
|
|
|
// Ping verifies that the database is reachable, opening a connection if the
// pool has none. It delegates to (*sql.DB).PingContext, so ctx bounds the wait.
func (c *SQLiteClient) Ping(ctx context.Context) error {
	db := c.db
	return db.PingContext(ctx)
}
|
|
|
|
// EnsureSchema executes the init_sqlite.sql script against the database.
//
// The script is looked up at "sql/init_sqlite.sql" and then
// "backend/sql/init_sqlite.sql" via readFirstExistingFile. Both paths are
// relative, so which one is found depends on the process working directory.
// Whether re-running the script is safe depends on the script itself
// (presumably CREATE ... IF NOT EXISTS — TODO confirm).
func (c *SQLiteClient) EnsureSchema(ctx context.Context) error {
	schemaSQL, readErr := readFirstExistingFile(
		filepath.Join("sql", "init_sqlite.sql"),
		filepath.Join("backend", "sql", "init_sqlite.sql"),
	)
	if readErr != nil {
		return readErr
	}

	if _, execErr := c.db.ExecContext(ctx, string(schemaSQL)); execErr != nil {
		return execErr
	}
	return nil
}
|
|
|
|
func (c *SQLiteClient) ResetOnlineDevices(ctx context.Context) error {
|
|
_, err := c.db.ExecContext(ctx, `UPDATE devices SET is_online = 0 WHERE is_online = 1`)
|
|
return err
|
|
}
|
|
|
|
// ResetOperationalData deletes every row from fallback_objects, transfers,
// sessions, rooms and devices inside one transaction. Either all deletes are
// committed together or the transaction is rolled back.
// admin_users and system_configs are not touched.
//
// NOTE(review): the fixed order presumably clears dependent tables before
// the tables they reference, so foreign-key enforcement never rejects a
// delete — confirm against init_sqlite.sql.
func (c *SQLiteClient) ResetOperationalData(ctx context.Context) error {
	statements := [...]string{
		`DELETE FROM fallback_objects`,
		`DELETE FROM transfers`,
		`DELETE FROM sessions`,
		`DELETE FROM rooms`,
		`DELETE FROM devices`,
	}

	tx, err := c.db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// After a successful Commit this Rollback is a no-op returning sql.ErrTxDone.
	defer func() { _ = tx.Rollback() }()

	for i := range statements {
		if _, execErr := tx.ExecContext(ctx, statements[i]); execErr != nil {
			return execErr
		}
	}

	return tx.Commit()
}
|
|
|
|
func (c *SQLiteClient) LoadSnapshot(ctx context.Context, fallbackRuntime model.RuntimeConfig) (store.Snapshot, error) {
|
|
snapshot := store.Snapshot{
|
|
Runtime: &fallbackRuntime,
|
|
}
|
|
|
|
devices, err := c.loadDevices(ctx)
|
|
if err != nil {
|
|
return store.Snapshot{}, err
|
|
}
|
|
snapshot.Devices = devices
|
|
|
|
rooms, err := c.loadRooms(ctx)
|
|
if err != nil {
|
|
return store.Snapshot{}, err
|
|
}
|
|
snapshot.Rooms = rooms
|
|
|
|
transfers, err := c.loadTransfers(ctx)
|
|
if err != nil {
|
|
return store.Snapshot{}, err
|
|
}
|
|
snapshot.Transfers = transfers
|
|
|
|
fallbackObjects, err := c.loadFallbackObjects(ctx)
|
|
if err != nil {
|
|
return store.Snapshot{}, err
|
|
}
|
|
snapshot.FallbackObjects = fallbackObjects
|
|
|
|
runtimeCfg, err := c.loadRuntimeConfig(ctx, fallbackRuntime)
|
|
if err != nil {
|
|
return store.Snapshot{}, err
|
|
}
|
|
snapshot.Runtime = &runtimeCfg
|
|
|
|
return snapshot, nil
|
|
}
|
|
|
|
func (c *SQLiteClient) PersistDevice(ctx context.Context, device model.Device) error {
|
|
_, err := c.db.ExecContext(ctx, `
|
|
INSERT INTO devices (
|
|
id, device_code, name, type, user_agent, network_group_key, public_ip_hash,
|
|
is_online, last_seen_at, created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
device_code = excluded.device_code,
|
|
name = excluded.name,
|
|
type = excluded.type,
|
|
user_agent = excluded.user_agent,
|
|
network_group_key = excluded.network_group_key,
|
|
public_ip_hash = excluded.public_ip_hash,
|
|
is_online = excluded.is_online,
|
|
last_seen_at = excluded.last_seen_at,
|
|
updated_at = excluded.updated_at
|
|
`,
|
|
device.ID,
|
|
device.ID,
|
|
device.Name,
|
|
device.Type,
|
|
nullableString(device.UserAgent),
|
|
nullableString(device.NetworkGroupKey),
|
|
nullableString(device.PublicIPHash),
|
|
boolToInt(device.IsOnline),
|
|
encodeTime(device.LastSeenAt),
|
|
encodeTime(device.CreatedAt),
|
|
encodeTime(time.Now()),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (c *SQLiteClient) PersistRoom(ctx context.Context, room model.Room) error {
|
|
_, err := c.db.ExecContext(ctx, `
|
|
INSERT INTO rooms (
|
|
code, creator_device_id, joiner_device_id, status, created_at, expires_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(code) DO UPDATE SET
|
|
creator_device_id = excluded.creator_device_id,
|
|
joiner_device_id = excluded.joiner_device_id,
|
|
status = excluded.status,
|
|
expires_at = excluded.expires_at,
|
|
updated_at = excluded.updated_at
|
|
`,
|
|
room.Code,
|
|
room.CreatorDeviceID,
|
|
nullableString(room.JoinerDeviceID),
|
|
room.Status,
|
|
encodeTime(room.CreatedAt),
|
|
encodeTime(room.ExpiresAt),
|
|
encodeTime(time.Now()),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (c *SQLiteClient) PersistTransfer(ctx context.Context, transfer model.Transfer) error {
|
|
_, err := c.db.ExecContext(ctx, `
|
|
INSERT INTO transfers (
|
|
id, session_id, kind, name, content, size_bytes,
|
|
sender_device_id, receiver_device_id, transfer_strategy, current_channel,
|
|
fallback_allowed, final_status, fallback_reason, object_key, expires_at,
|
|
created_at, updated_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(id) DO UPDATE SET
|
|
session_id = excluded.session_id,
|
|
kind = excluded.kind,
|
|
name = excluded.name,
|
|
content = excluded.content,
|
|
size_bytes = excluded.size_bytes,
|
|
sender_device_id = excluded.sender_device_id,
|
|
receiver_device_id = excluded.receiver_device_id,
|
|
transfer_strategy = excluded.transfer_strategy,
|
|
current_channel = excluded.current_channel,
|
|
fallback_allowed = excluded.fallback_allowed,
|
|
final_status = excluded.final_status,
|
|
fallback_reason = excluded.fallback_reason,
|
|
object_key = excluded.object_key,
|
|
expires_at = excluded.expires_at,
|
|
updated_at = excluded.updated_at
|
|
`,
|
|
transfer.ID,
|
|
nullableString(transfer.SessionID),
|
|
transfer.Kind,
|
|
transfer.Name,
|
|
nullableString(transfer.Content),
|
|
transfer.SizeBytes,
|
|
transfer.SenderDeviceID,
|
|
transfer.ReceiverDeviceID,
|
|
transfer.TransferStrategy,
|
|
transfer.CurrentChannel,
|
|
boolToInt(transfer.FallbackAllowed),
|
|
transfer.FinalStatus,
|
|
nullableString(transfer.FallbackReason),
|
|
nullableString(transfer.ObjectKey),
|
|
nullableTime(transfer.ExpiresAt),
|
|
encodeTime(transfer.CreatedAt),
|
|
encodeTime(transfer.UpdatedAt),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (c *SQLiteClient) PersistFallbackObject(ctx context.Context, object model.FallbackObject) error {
|
|
_, err := c.db.ExecContext(ctx, `
|
|
INSERT INTO fallback_objects (
|
|
transfer_id, bucket, object_key, size_bytes, cleanup_state, created_at, expires_at, cleaned_at
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(transfer_id) DO UPDATE SET
|
|
bucket = excluded.bucket,
|
|
object_key = excluded.object_key,
|
|
size_bytes = excluded.size_bytes,
|
|
cleanup_state = excluded.cleanup_state,
|
|
expires_at = excluded.expires_at,
|
|
cleaned_at = excluded.cleaned_at
|
|
`,
|
|
object.TransferID,
|
|
"filefast-fallback",
|
|
object.ObjectKey,
|
|
object.SizeBytes,
|
|
object.CleanupState,
|
|
encodeTime(object.CreatedAt),
|
|
encodeTime(object.ExpiresAt),
|
|
nullableTime(object.CleanedAt),
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (c *SQLiteClient) PersistRuntimeConfig(ctx context.Context, runtime model.RuntimeConfig) error {
|
|
payload, err := json.Marshal(runtime)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = c.db.ExecContext(ctx, `
|
|
INSERT INTO system_configs (config_key, config_value, updated_at)
|
|
VALUES (?, ?, ?)
|
|
ON CONFLICT(config_key) DO UPDATE SET
|
|
config_value = excluded.config_value,
|
|
updated_at = excluded.updated_at
|
|
`, runtimeConfigKey, string(payload), encodeTime(time.Now()))
|
|
return err
|
|
}
|
|
|
|
// EnsureAdminUser seeds an active admin account if username is not present yet.
//
// Both arguments are whitespace-trimmed. If either is empty afterwards, the
// call does nothing and returns nil. The password is stored as a bcrypt hash
// at bcrypt.DefaultCost. An existing row with the same username is left
// untouched (ON CONFLICT DO NOTHING), so changing the configured password
// does not rotate an existing admin's hash.
func (c *SQLiteClient) EnsureAdminUser(ctx context.Context, username, password string) error {
	username = strings.TrimSpace(username)
	password = strings.TrimSpace(password)
	if username == "" || password == "" {
		return nil
	}

	hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return err
	}

	// Capture the clock once so a freshly inserted row has identical
	// created_at and updated_at values. Two time.Now() calls would differ.
	now := encodeTime(time.Now())

	_, err = c.db.ExecContext(ctx, `
		INSERT INTO admin_users (username, password_hash, is_active, created_at, updated_at)
		VALUES (?, ?, 1, ?, ?)
		ON CONFLICT(username) DO NOTHING
	`, username, string(hash), now, now)
	return err
}
|
|
|
|
// ValidateAdminCredentials reports whether username and password match an
// active row in admin_users.
//
// Both inputs are whitespace-trimmed, matching EnsureAdminUser, which trims
// before hashing. The result is (false, nil) for an unknown username, an
// inactive account, or a wrong password. A non-nil error is returned only
// when the lookup itself fails.
//
// Unknown and inactive accounts still pay for one bcrypt hash at
// bcrypt.DefaultCost, the cost EnsureAdminUser uses. Response time therefore
// does not reveal whether a username exists or is disabled, assuming stored
// hashes were produced at that cost.
func (c *SQLiteClient) ValidateAdminCredentials(ctx context.Context, username, password string) (bool, error) {
	password = strings.TrimSpace(password)

	// spendComparableBcryptTime burns roughly the CPU of the real
	// CompareHashAndPassword below. Its result is deliberately discarded.
	spendComparableBcryptTime := func() {
		_, _ = bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	}

	var (
		passwordHash string
		isActive     int
	)
	err := c.db.QueryRowContext(ctx, `
		SELECT password_hash, is_active
		FROM admin_users
		WHERE username = ?
	`, strings.TrimSpace(username)).Scan(&passwordHash, &isActive)

	switch {
	case errors.Is(err, sql.ErrNoRows):
		spendComparableBcryptTime()
		return false, nil
	case err != nil:
		return false, err
	case isActive == 0:
		spendComparableBcryptTime()
		return false, nil
	}

	return bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(password)) == nil, nil
}
|
|
|
|
func (c *SQLiteClient) Close() error {
|
|
return c.db.Close()
|
|
}
|
|
|
|
// loadDevices reads every row of the devices table, in no particular order.
//
// NULL user_agent, network_group_key and public_ip_hash come back as "".
// A NULL last_seen_at falls back to created_at. device_code is not read.
// A scan or timestamp-decoding error aborts with (nil, err). Otherwise the
// collected devices are returned together with rows.Err().
func (c *SQLiteClient) loadDevices(ctx context.Context) ([]model.Device, error) {
	const selectDevices = `
		SELECT
			id,
			name,
			type,
			COALESCE(user_agent, ''),
			COALESCE(network_group_key, ''),
			COALESCE(public_ip_hash, ''),
			is_online,
			COALESCE(last_seen_at, created_at),
			created_at
		FROM devices
	`

	rows, err := c.db.QueryContext(ctx, selectDevices)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []model.Device
	for rows.Next() {
		var (
			dev         model.Device
			online      int
			lastSeenRaw string
			createdRaw  string
		)
		if scanErr := rows.Scan(
			&dev.ID,
			&dev.Name,
			&dev.Type,
			&dev.UserAgent,
			&dev.NetworkGroupKey,
			&dev.PublicIPHash,
			&online,
			&lastSeenRaw,
			&createdRaw,
		); scanErr != nil {
			return nil, scanErr
		}

		dev.IsOnline = online != 0
		if dev.LastSeenAt, err = decodeTime(lastSeenRaw); err != nil {
			return nil, err
		}
		if dev.CreatedAt, err = decodeTime(createdRaw); err != nil {
			return nil, err
		}

		out = append(out, dev)
	}

	return out, rows.Err()
}
|
|
|
|
// loadRooms reads every row of the rooms table, in no particular order.
//
// A NULL joiner_device_id comes back as "". A scan or timestamp-decoding
// error aborts with (nil, err). Otherwise the collected rooms are returned
// together with rows.Err().
func (c *SQLiteClient) loadRooms(ctx context.Context) ([]model.Room, error) {
	const selectRooms = `
		SELECT
			code,
			creator_device_id,
			COALESCE(joiner_device_id, ''),
			status,
			created_at,
			expires_at
		FROM rooms
	`

	rows, err := c.db.QueryContext(ctx, selectRooms)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []model.Room
	for rows.Next() {
		var (
			rm         model.Room
			createdRaw string
			expiresRaw string
		)
		if scanErr := rows.Scan(
			&rm.Code,
			&rm.CreatorDeviceID,
			&rm.JoinerDeviceID,
			&rm.Status,
			&createdRaw,
			&expiresRaw,
		); scanErr != nil {
			return nil, scanErr
		}

		if rm.CreatedAt, err = decodeTime(createdRaw); err != nil {
			return nil, err
		}
		if rm.ExpiresAt, err = decodeTime(expiresRaw); err != nil {
			return nil, err
		}

		out = append(out, rm)
	}

	return out, rows.Err()
}
|
|
|
|
// loadTransfers reads every row of the transfers table, in no particular order.
//
// NULL session_id, content, fallback_reason and object_key come back as "".
// fallback_allowed is treated as true when non-zero. ExpiresAt stays nil when
// expires_at is NULL or blank. A scan or timestamp-decoding error aborts with
// (nil, err). Otherwise the collected transfers are returned together with
// rows.Err().
func (c *SQLiteClient) loadTransfers(ctx context.Context) ([]model.Transfer, error) {
	const selectTransfers = `
		SELECT
			id,
			COALESCE(session_id, ''),
			kind,
			name,
			COALESCE(content, ''),
			size_bytes,
			sender_device_id,
			receiver_device_id,
			transfer_strategy,
			current_channel,
			fallback_allowed,
			final_status,
			created_at,
			updated_at,
			COALESCE(fallback_reason, ''),
			COALESCE(object_key, ''),
			expires_at
		FROM transfers
	`

	rows, err := c.db.QueryContext(ctx, selectTransfers)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []model.Transfer
	for rows.Next() {
		var (
			tr           model.Transfer
			createdRaw   string
			updatedRaw   string
			expiresRaw   sql.NullString
			fallbackFlag int
		)
		if scanErr := rows.Scan(
			&tr.ID,
			&tr.SessionID,
			&tr.Kind,
			&tr.Name,
			&tr.Content,
			&tr.SizeBytes,
			&tr.SenderDeviceID,
			&tr.ReceiverDeviceID,
			&tr.TransferStrategy,
			&tr.CurrentChannel,
			&fallbackFlag,
			&tr.FinalStatus,
			&createdRaw,
			&updatedRaw,
			&tr.FallbackReason,
			&tr.ObjectKey,
			&expiresRaw,
		); scanErr != nil {
			return nil, scanErr
		}

		tr.FallbackAllowed = fallbackFlag != 0
		if tr.CreatedAt, err = decodeTime(createdRaw); err != nil {
			return nil, err
		}
		if tr.UpdatedAt, err = decodeTime(updatedRaw); err != nil {
			return nil, err
		}
		if expiresRaw.Valid && strings.TrimSpace(expiresRaw.String) != "" {
			expiry, decodeErr := decodeTime(expiresRaw.String)
			if decodeErr != nil {
				return nil, decodeErr
			}
			tr.ExpiresAt = &expiry
		}

		out = append(out, tr)
	}

	return out, rows.Err()
}
|
|
|
|
// loadFallbackObjects reads every row of the fallback_objects table, in no
// particular order.
//
// The bucket column is not read. CleanedAt stays nil when cleaned_at is NULL
// or blank. A scan or timestamp-decoding error aborts with (nil, err).
// Otherwise the collected objects are returned together with rows.Err().
func (c *SQLiteClient) loadFallbackObjects(ctx context.Context) ([]model.FallbackObject, error) {
	const selectFallbackObjects = `
		SELECT
			transfer_id,
			object_key,
			size_bytes,
			created_at,
			expires_at,
			cleaned_at,
			cleanup_state
		FROM fallback_objects
	`

	rows, err := c.db.QueryContext(ctx, selectFallbackObjects)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []model.FallbackObject
	for rows.Next() {
		var (
			obj        model.FallbackObject
			createdRaw string
			expiresRaw string
			cleanedRaw sql.NullString
		)
		if scanErr := rows.Scan(
			&obj.TransferID,
			&obj.ObjectKey,
			&obj.SizeBytes,
			&createdRaw,
			&expiresRaw,
			&cleanedRaw,
			&obj.CleanupState,
		); scanErr != nil {
			return nil, scanErr
		}

		if obj.CreatedAt, err = decodeTime(createdRaw); err != nil {
			return nil, err
		}
		if obj.ExpiresAt, err = decodeTime(expiresRaw); err != nil {
			return nil, err
		}
		if cleanedRaw.Valid && strings.TrimSpace(cleanedRaw.String) != "" {
			cleaned, decodeErr := decodeTime(cleanedRaw.String)
			if decodeErr != nil {
				return nil, decodeErr
			}
			obj.CleanedAt = &cleaned
		}

		out = append(out, obj)
	}

	return out, rows.Err()
}
|
|
|
|
// loadRuntimeConfig fetches the JSON stored under runtimeConfigKey in
// system_configs and decodes it with decodeRuntimeConfig.
//
// fallback is returned as-is when no such row exists, and is also handed to
// decodeRuntimeConfig. Any other query error yields a zero RuntimeConfig and
// that error.
func (c *SQLiteClient) loadRuntimeConfig(ctx context.Context, fallback model.RuntimeConfig) (model.RuntimeConfig, error) {
	const selectRuntimeConfig = `
		SELECT config_value
		FROM system_configs
		WHERE config_key = ?
	`

	var raw string
	row := c.db.QueryRowContext(ctx, selectRuntimeConfig, runtimeConfigKey)

	switch err := row.Scan(&raw); {
	case err == nil:
		return decodeRuntimeConfig([]byte(raw), fallback)
	case errors.Is(err, sql.ErrNoRows):
		return fallback, nil
	default:
		return model.RuntimeConfig{}, err
	}
}
|
|
|
|
// encodeTime renders value in UTC using time.RFC3339Nano. This is the text
// form in which every timestamp parameter in this file is written.
//
// NOTE(review): RFC3339Nano drops trailing zeros in the fractional part, so
// encoded values are not fixed-width. Two instants in the same second can
// sort lexically out of time order (e.g. "...:05Z" > "...:05.5Z"). Compare
// decoded values in Go rather than relying on SQL string comparison.
func encodeTime(value time.Time) string {
	utc := value.In(time.UTC)
	return utc.Format(time.RFC3339Nano)
}
|
|
|
|
// nullableTime converts an optional timestamp into a SQL argument. A nil
// pointer becomes an untyped nil, bound as NULL. Otherwise the value is
// formatted with encodeTime.
func nullableTime(value *time.Time) any {
	if value != nil {
		return encodeTime(*value)
	}
	return nil
}
|
|
|
|
// sqliteDateTimeLayout is the text format SQLite's datetime() and
// CURRENT_TIMESTAMP produce ("YYYY-MM-DD HH:MM:SS", always UTC). When
// parsing, Go also accepts an optional fractional-seconds suffix.
const sqliteDateTimeLayout = "2006-01-02 15:04:05"

// decodeTime parses a timestamp column value.
//
// Values written by encodeTime (RFC 3339, optional fraction) are parsed
// first. If that fails, SQLite's native datetime text form is accepted. It
// has no zone and is interpreted as UTC, so rows filled by SQL defaults or
// manual tooling do not abort a whole snapshot load. If neither layout
// matches, the zero time and the original RFC 3339 parse error are returned.
func decodeTime(value string) (time.Time, error) {
	parsed, err := time.Parse(time.RFC3339Nano, value)
	if err == nil {
		return parsed, nil
	}
	if native, nativeErr := time.Parse(sqliteDateTimeLayout, value); nativeErr == nil {
		return native, nil
	}
	return time.Time{}, err
}
|
|
|
|
// nullableString converts a string into a SQL argument. Empty or
// whitespace-only input becomes an untyped nil, bound as NULL. Any other
// input is passed through unchanged, including its surrounding whitespace.
func nullableString(value string) any {
	if strings.TrimSpace(value) != "" {
		return value
	}
	return nil
}
|
|
|
|
// boolToInt maps true to 1 and false to 0 for SQLite's integer-backed
// boolean columns.
func boolToInt(value bool) int {
	result := 0
	if value {
		result = 1
	}
	return result
}
|