616 lines
19 KiB
Python
616 lines
19 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
|
||
from __future__ import annotations
|
||
|
||
import argparse
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
import os
|
||
import sqlite3
|
||
import sys
|
||
import time
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
from typing import Iterable, Optional, Tuple
|
||
|
||
import requests
|
||
|
||
try:
|
||
# 可选:用于从 .env 文件加载环境变量(UPLOAD_JSON_SUBDIRS / UPLOAD_JSON_SUBDIR 等)
|
||
from dotenv import load_dotenv # type: ignore
|
||
except Exception: # pragma: no cover
|
||
load_dotenv = None # type: ignore
|
||
|
||
# Default upload endpoint and state-database file name.
# Both are evaluated once, at import time, from the process environment
# (UPLOAD_ENDPOINT / UPLOAD_DB), falling back to the literals below.
DEFAULT_ENDPOINT = (
    os.environ.get("UPLOAD_ENDPOINT")
    or "http://127.0.0.1:8317/v0/management/auth-files"
).strip()
DEFAULT_DB_NAME = (os.environ.get("UPLOAD_DB") or "upload_state.sqlite3").strip()

# Environment variables naming the sub-directories that contain the JSON files
# (relative to --dir / the current directory). Several may be configured.
# Examples:
#   Windows: set UPLOAD_JSON_SUBDIRS=data;data2
#   bash:    export UPLOAD_JSON_SUBDIRS="data:data2"  (note: on Windows os.pathsep is ';')
# Comma-separated values are accepted as well: data,data2
ENV_JSON_SUBDIRS = "UPLOAD_JSON_SUBDIRS"
ENV_JSON_SUBDIR = "UPLOAD_JSON_SUBDIR"  # a single sub-directory (compatibility / shorthand)

# List of upload sites (endpoint URLs); recommended to be configured in .env.
# Example: UPLOAD_ENDPOINTS=http://a/v0/management/auth-files;http://b/v0/management/auth-files
ENV_ENDPOINTS = "UPLOAD_ENDPOINTS"
# Name of the environment variable holding the upload auth token.
ENV_TOKEN = "UPLOAD_TOKEN"
|
||
|
||
|
||
@dataclass(frozen=True)
class UploadResult:
    """Immutable outcome of a single upload attempt, as produced by upload_file()."""

    # True when the server answered 2xx, or when the response was classified
    # as "already exists"; False for every other outcome.
    ok: bool
    # HTTP status of the response; None when no response was obtained
    # (invalid extra headers or a request error).
    status_code: Optional[int]
    # Response body, or a short error description when no response was obtained.
    text: str
|
||
|
||
|
||
def setup_logger() -> logging.Logger:
    """Return the shared "uploader" logger, attaching its stdout handler once.

    The logger level is always (re)set to INFO. A stream handler writing to
    stdout is attached only when the logger has no handlers yet, so repeated
    calls never produce duplicated output lines. Records are rendered as
    ``[HH:MM:SS] [<level name in Chinese>] message``.
    """
    log = logging.getLogger("uploader")
    log.setLevel(logging.INFO)

    # Already configured (e.g. some IDE run modes call this twice): nothing to add.
    if log.handlers:
        return log

    class _CnFormatter(logging.Formatter):
        """Formatter exposing a ``levelname_cn`` field with the translated level."""

        LEVEL_MAP = {
            "DEBUG": "调试",
            "INFO": "信息",
            "WARNING": "警告",
            "ERROR": "错误",
            "CRITICAL": "严重",
        }

        def format(self, record: logging.LogRecord) -> str:
            # Unknown level names fall back to the original English name.
            translated = self.LEVEL_MAP.get(record.levelname, record.levelname)
            record.levelname_cn = translated
            return super().format(record)

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setLevel(logging.INFO)
    stdout_handler.setFormatter(
        _CnFormatter(
            fmt="[%(asctime)s] [%(levelname_cn)s] %(message)s",
            datefmt="%H:%M:%S",
        )
    )
    log.addHandler(stdout_handler)
    return log
|
||
|
||
|
||
def init_db(db_path: Path) -> sqlite3.Connection:
    """Open (creating if necessary) the upload-state database and return the connection.

    Ensures the ``uploads_v2`` table (one row per ``endpoint`` + ``file_path``)
    and its status index exist. If a legacy ``uploads`` table is present and
    ``uploads_v2`` is still empty, the legacy rows are copied over and
    attributed to ``DEFAULT_ENDPOINT``.

    The migration is best-effort: a database error during migration is logged
    as a warning and otherwise ignored, so it never blocks the main flow.

    Args:
        db_path: Location of the SQLite database file.

    Returns:
        An open connection with the schema committed.

    Raises:
        sqlite3.Error: If the database cannot be opened or the schema cannot
            be created; the connection is closed before re-raising.
    """
    conn = sqlite3.connect(str(db_path))

    try:
        # v2 schema: de-duplicate by endpoint + file_path
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS uploads_v2 (
                endpoint TEXT NOT NULL,
                file_path TEXT NOT NULL,
                file_name TEXT NOT NULL,
                sha256 TEXT NOT NULL,
                size_bytes INTEGER NOT NULL,
                mtime_ns INTEGER NOT NULL,
                status TEXT NOT NULL, -- success | skipped | failed
                http_status INTEGER,
                response_text TEXT,
                updated_at INTEGER NOT NULL,
                PRIMARY KEY (endpoint, file_path)
            )
            """
        )
        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_uploads_v2_status ON uploads_v2(status)"
        )
    except sqlite3.Error:
        # Do not leak the connection when the schema cannot be created.
        conn.close()
        raise

    # Compatibility migration: copy rows from the legacy `uploads` table into
    # v2 (under DEFAULT_ENDPOINT), only when v2 holds no data yet.
    try:
        legacy_table = conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='uploads'"
        ).fetchone()
        if legacy_table is not None:
            v2_count = conn.execute("SELECT COUNT(1) FROM uploads_v2").fetchone()[0]
            if v2_count == 0:
                conn.execute(
                    """
                    INSERT OR IGNORE INTO uploads_v2(
                        endpoint, file_path, file_name, sha256, size_bytes, mtime_ns,
                        status, http_status, response_text, updated_at
                    )
                    SELECT
                        ?, file_path, file_name, sha256, size_bytes, mtime_ns,
                        status, http_status, response_text, updated_at
                    FROM uploads
                    """,
                    (DEFAULT_ENDPOINT,),
                )
    except sqlite3.Error as exc:
        # A failed migration must not stop the uploader, but it should be visible.
        logging.getLogger("uploader").warning("旧表 uploads 迁移失败(已忽略): %s", exc)

    conn.commit()
    return conn
|
||
|
||
|
||
def sha256_of_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hexadecimal SHA-256 digest of the file at *path*.

    The file is read in binary mode, ``chunk_size`` bytes at a time, so the
    whole content is never held in memory at once.
    """
    hasher = hashlib.sha256()
    with path.open("rb") as stream:
        # iter() with a sentinel stops at the first empty read, i.e. end of file.
        for block in iter(lambda: stream.read(chunk_size), b""):
            hasher.update(block)
    return hasher.hexdigest()
|
||
|
||
|
||
def get_file_fingerprint(path: Path) -> Tuple[int, int, str]:
    """Return ``(size_bytes, mtime_ns, sha256_hex)`` for the file at *path*.

    Size and modification time come from a single ``stat()`` call made before
    the content is hashed. ``st_mtime_ns`` is used when the stat result
    provides it; otherwise the float ``st_mtime`` is converted to nanoseconds.
    """
    info = path.stat()
    fallback_ns = int(info.st_mtime * 1e9)
    modified_ns = int(getattr(info, "st_mtime_ns", fallback_ns))
    return int(info.st_size), modified_ns, sha256_of_file(path)
|
||
|
||
|
||
def is_file_stable(path: Path, stable_seconds: float = 1.5) -> bool:
    """Tell whether *path* looks finished, to avoid uploading a file mid-write.

    The file is stat'ed, the call sleeps for ``stable_seconds``, and the file
    is stat'ed again. It counts as stable only when both size and mtime (ns)
    are unchanged. A file that is missing at either check is not stable.
    """
    try:
        before = path.stat()
        time.sleep(stable_seconds)
        after = path.stat()
    except FileNotFoundError:
        return False

    snapshot_before = (before.st_size, before.st_mtime_ns)
    snapshot_after = (after.st_size, after.st_mtime_ns)
    return snapshot_before == snapshot_after
|
||
|
||
|
||
def iter_json_files(folder: Path) -> Iterable[Path]:
    """Yield every regular ``*.json`` file found under *folder*, recursively.

    To scan only the top level (no sub-directories), replace ``rglob`` with ``glob``.
    """
    yield from (entry for entry in folder.rglob("*.json") if entry.is_file())
|
||
|
||
|
||
def parse_list_from_env(var_name: str) -> list[str]:
    """Read environment variable *var_name* and split it into a list of entries.

    Separators are ``,``, ``;`` and ``os.pathsep``. On POSIX that means
    ``a;b``, ``a:b`` and ``a,b`` all yield two entries. On Windows
    ``os.pathsep`` is ``;``, so ``:`` is not a separator there and drive
    letters such as ``C:\\data`` stay intact.

    Entries are stripped of surrounding whitespace, empty entries are dropped,
    and duplicates are removed while keeping first-seen order.

    Args:
        var_name: Name of the environment variable to read.

    Returns:
        The parsed entries; an empty list when the variable is unset or blank.
    """
    raw = os.environ.get(var_name, "").strip()
    if not raw:
        return []

    # Map every accepted separator onto ',' so a single split handles them all.
    to_comma = {ord(sep): "," for sep in (";", os.pathsep)}
    tokens = (piece.strip() for piece in raw.translate(to_comma).split(","))

    # dict.fromkeys removes duplicates while preserving insertion order.
    return list(dict.fromkeys(token for token in tokens if token))
|
||
|
||
|
||
def parse_endpoints_from_env(fallback_endpoint: str) -> list[str]:
    """Return the list of endpoints to upload to.

    - The UPLOAD_ENDPOINTS environment variable takes precedence.
    - Endpoints may only be separated by ``;`` or ``,`` (never ``:``), because
      URLs themselves contain ``:`` in ``http://`` and in port numbers.
    - When nothing usable is configured, the single *fallback_endpoint*
      (the command-line ``--endpoint`` value) is returned.
    """
    configured = os.environ.get(ENV_ENDPOINTS, "").strip()
    candidates = (piece.strip() for piece in configured.replace(",", ";").split(";"))
    endpoints = [candidate for candidate in candidates if candidate]
    return endpoints or [fallback_endpoint]
|
||
|
||
|
||
def normalize_token(raw: str) -> str:
    """Return the bare token from *raw*, dropping an optional ``Bearer `` prefix.

    Surrounding whitespace is stripped and the prefix is matched
    case-insensitively. A falsy or blank input yields an empty string.
    """
    cleaned = (raw or "").strip()
    if cleaned.lower().startswith("bearer "):
        # Everything after the first space is the token itself.
        _, _, remainder = cleaned.partition(" ")
        return remainder.strip()
    return cleaned
|
||
|
||
|
||
def parse_subdirs_from_env(base_dir: Path) -> list[Path]:
    """Read the JSON sub-directory list from the environment.

    Supported variables:
    - UPLOAD_JSON_SUBDIRS: several sub-directories separated by ``os.pathsep``
      (``;`` on Windows, ``:`` on Linux/macOS); comma separation also works.
    - UPLOAD_JSON_SUBDIR: a single sub-directory (compatibility / shorthand),
      consulted only when UPLOAD_JSON_SUBDIRS is blank.

    Absolute entries are used as given; relative entries are joined onto
    *base_dir*.

    Returns:
        Resolved directories in configured order, without duplicates; entries
        that do not exist or are not directories are dropped.
    """
    multi_value = os.environ.get(ENV_JSON_SUBDIRS, "")
    raw = multi_value.strip() or os.environ.get(ENV_JSON_SUBDIR, "").strip()
    if not raw:
        return []

    names = parse_list_from_env(ENV_JSON_SUBDIRS) if multi_value else []
    if not names:
        # Nothing came out of the list variable: treat the raw value as one entry.
        names = [raw]

    # Keyed by string form so the first occurrence wins and order is preserved.
    unique_dirs: dict[str, Path] = {}
    for name in names:
        as_given = Path(name).expanduser()
        if as_given.is_absolute():
            target = as_given.resolve()
        else:
            target = (base_dir / name).expanduser().resolve()

        if target.exists() and target.is_dir():
            unique_dirs.setdefault(str(target), target)

    return list(unique_dirs.values())
|
||
|
||
|
||
def db_has_success_or_skipped(
    conn: sqlite3.Connection, endpoint: str, file_path: str
) -> bool:
    """Tell whether *file_path* is already settled for *endpoint*.

    Returns True when ``uploads_v2`` holds a row for the pair whose status is
    ``success`` or ``skipped``; False when there is no row or the status is
    anything else.
    """
    cursor = conn.execute(
        "SELECT status FROM uploads_v2 WHERE endpoint = ? AND file_path = ?",
        (endpoint, file_path),
    )
    record = cursor.fetchone()
    return record is not None and record[0] in ("success", "skipped")
|
||
|
||
|
||
def db_upsert(
    conn: sqlite3.Connection,
    *,
    endpoint: str,
    file_path: str,
    file_name: str,
    sha256: str,
    size_bytes: int,
    mtime_ns: int,
    status: str,
    http_status: Optional[int],
    response_text: str,
) -> None:
    """Insert or update the ``uploads_v2`` row for ``(endpoint, file_path)``.

    Every non-key column is overwritten on conflict. ``response_text`` is
    truncated to its first 4000 characters, ``updated_at`` is set to the
    current Unix time in whole seconds, and the change is committed at once.
    """
    upsert_sql = """
        INSERT INTO uploads_v2(
            endpoint, file_path, file_name, sha256, size_bytes, mtime_ns,
            status, http_status, response_text, updated_at
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ON CONFLICT(endpoint, file_path) DO UPDATE SET
            file_name=excluded.file_name,
            sha256=excluded.sha256,
            size_bytes=excluded.size_bytes,
            mtime_ns=excluded.mtime_ns,
            status=excluded.status,
            http_status=excluded.http_status,
            response_text=excluded.response_text,
            updated_at=excluded.updated_at
        """
    timestamp = int(time.time())
    row_values = (
        endpoint,
        file_path,
        file_name,
        sha256,
        size_bytes,
        mtime_ns,
        status,
        http_status,
        response_text[:4000],  # cap the stored response size
        timestamp,
    )
    conn.execute(upsert_sql, row_values)
    conn.commit()
|
||
|
||
|
||
def should_treat_as_duplicate(resp: requests.Response) -> bool:
    """Decide from a non-2xx response whether the server already has the file.

    The server's existing file list cannot be fetched in advance, so the
    decision relies on the response alone. Strengthen the checks here to match
    the actual API response format.

    Rules, in order:
    - HTTP 409 and 208 always mean "already exists".
    - HTTP 401 and 403 never do. They are authentication/authorization
      failures, and a body such as "token already expired" must not mark the
      file as skipped, because skipped files are not retried.
    - Otherwise the lower-cased body is searched for common "already exists"
      keywords (English and Chinese).

    Args:
        resp: Response object; only ``status_code`` and ``text`` are read.

    Returns:
        True when the response should be recorded as a duplicate/skip.
    """
    status = resp.status_code
    if status in (409, 208):
        return True

    # Auth failures are never duplicates, whatever the body says.
    if status in (401, 403):
        return False

    body = (resp.text or "").lower()
    # Common hint keywords (add or remove as needed)
    keywords = (
        "already",
        "exists",
        "duplicate",
        "重复",
        "已存在",
        "已上传",
    )
    return any(keyword in body for keyword in keywords)
|
||
|
||
|
||
def upload_file(
    *,
    endpoint: str,
    token: str,
    path: Path,
    timeout_s: int,
    verify_tls: bool,
    extra_headers_json: Optional[str] = None,
) -> UploadResult:
    """Upload one file to *endpoint* as a multipart form field named ``file``.

    Args:
        endpoint: Upload URL.
        token: Bearer token sent in the ``Authorization`` header.
        path: File to upload; sent with content type ``application/json``.
        timeout_s: Request timeout in seconds.
        verify_tls: Whether to verify the server's TLS certificate.
        extra_headers_json: Optional JSON object of additional request
            headers; keys and values are converted to strings.

    Returns:
        An ``UploadResult``:
        - ``ok=True`` with the HTTP status for any 2xx response, and for a
          non-2xx response classified as "already exists" (the caller labels
          the latter as skipped).
        - ``ok=False`` with the HTTP status for any other response.
        - ``ok=False`` with ``status_code=None`` when the extra headers are
          invalid, the file cannot be opened, or the request itself fails.

    This function reports failures through the return value rather than
    raising, so one bad file cannot stop the caller's scan loop.
    """
    headers = {
        "Accept": "application/json, text/plain, */*",
        "Authorization": f"Bearer {token}",
    }

    if extra_headers_json:
        try:
            extra = json.loads(extra_headers_json)
            if not isinstance(extra, dict):
                raise ValueError("extra headers must be a JSON object")
            for k, v in extra.items():
                headers[str(k)] = str(v)
        except Exception as e:
            return UploadResult(False, None, f"Invalid --extra-headers-json: {e}")

    # requests generates the multipart boundary itself; do not set Content-Type
    # by hand, otherwise the boundary will not match.
    try:
        with path.open("rb") as f:
            files = {
                "file": (path.name, f, "application/json"),
            }
            try:
                resp = requests.post(
                    endpoint,
                    headers=headers,
                    files=files,
                    timeout=timeout_s,
                    verify=verify_tls,
                )
            except Exception as e:
                return UploadResult(False, None, f"Request error: {e}")
    except OSError as e:
        # The file vanished or became unreadable after it was discovered.
        return UploadResult(False, None, f"File error: {e}")

    # Any 2xx counts as success
    if 200 <= resp.status_code < 300:
        return UploadResult(True, resp.status_code, resp.text or "")

    # "Already exists" responses are reported as ok=True; the caller marks them skipped.
    if should_treat_as_duplicate(resp):
        return UploadResult(True, resp.status_code, resp.text or "")

    return UploadResult(False, resp.status_code, resp.text or "")
|
||
|
||
|
||
def _list_files_by_mtime(scan_dirs: list[Path]) -> list[Path]:
    """Collect JSON files under *scan_dirs*, oldest mtime first, skipping files that vanish mid-scan."""
    stamped: list[Tuple[int, Path]] = []
    for directory in scan_dirs:
        for candidate in iter_json_files(directory):
            try:
                stamped.append((candidate.stat().st_mtime_ns, candidate))
            except OSError:
                # Deleted or unreadable between discovery and stat: ignore this round.
                continue
    # Sort on the timestamp only; the sort is stable, so ties keep scan order.
    stamped.sort(key=lambda item: item[0])
    return [path for _, path in stamped]


def _process_file(
    conn: sqlite3.Connection,
    logger: logging.Logger,
    path: Path,
    endpoints: list[str],
    token: str,
    args: argparse.Namespace,
) -> bool:
    """Upload *path* to every endpoint not yet settled; return True if any work was attempted."""
    file_path_str = str(path)
    pending = [
        ep for ep in endpoints if not db_has_success_or_skipped(conn, ep, file_path_str)
    ]
    if not pending:
        return False

    # Stability and fingerprint are per-file facts, so check them once per file.
    if not is_file_stable(path):
        logger.info(f"跳过(正在写入): {path.name}")
        return False

    try:
        size_bytes, mtime_ns, digest = get_file_fingerprint(path)
    except Exception as e:
        logger.error(f"失败(读取/计算哈希): {path.name}: {e}")
        for endpoint in pending:
            db_upsert(
                conn,
                endpoint=endpoint,
                file_path=file_path_str,
                file_name=path.name,
                sha256="",
                size_bytes=0,
                mtime_ns=0,
                status="failed",
                http_status=None,
                response_text=str(e),
            )
        return True

    for endpoint in pending:
        result = upload_file(
            endpoint=endpoint,
            token=token,
            path=path,
            timeout_s=args.timeout,
            verify_tls=args.verify_tls,
            extra_headers_json=args.extra_headers_json,
        )

        if result.ok:
            # ok with a non-2xx status means the server reported "already exists".
            is_duplicate = result.status_code is not None and not (
                200 <= result.status_code < 300
            )
            status = "skipped" if is_duplicate else "success"
            if status == "success":
                logger.info(
                    f"成功: {path.name} -> {endpoint} | HTTP={result.status_code}"
                )
            else:
                logger.info(
                    f"跳过(已存在): {path.name} -> {endpoint} | HTTP={result.status_code}"
                )
        else:
            status = "failed"
            logger.error(
                f"失败: {path.name} -> {endpoint} | HTTP={result.status_code} | {result.text[:300]}"
            )

        db_upsert(
            conn,
            endpoint=endpoint,
            file_path=file_path_str,
            file_name=path.name,
            sha256=digest,
            size_bytes=size_bytes,
            mtime_ns=mtime_ns,
            status=status,
            http_status=result.status_code,
            response_text=result.text,
        )

    return True


def _wait_for_next_scan(wait_s: int, show_countdown: bool) -> None:
    """Sleep *wait_s* seconds, optionally showing a one-line countdown on stdout."""
    if not show_countdown:
        time.sleep(wait_s)
        return

    # Countdown refreshed on the same line, to avoid flooding the screen
    for remaining in range(wait_s, 0, -1):
        sys.stdout.write(
            f"\r[{time.strftime('%H:%M:%S')}] [信息] 暂无新文件,等待中... {remaining} 秒"
        )
        sys.stdout.flush()
        time.sleep(1)
    sys.stdout.write("\n")
    sys.stdout.flush()


def main() -> int:
    """Command-line entry point: scan for JSON files and upload the new ones.

    Flow:
    1. Load ``.env`` (when python-dotenv is installed) before building the
       argument parser, so UPLOAD_ENDPOINT / UPLOAD_DB set only in ``.env``
       still become the ``--endpoint`` / ``--db`` defaults.
    2. Resolve the working directory, scan directories, token and endpoints.
    3. Repeatedly scan, upload each file (oldest first) to every endpoint that
       has not yet recorded success/skipped for it, and persist the outcome.
       With ``--once`` a single round is performed.

    Returns:
        0 after a completed ``--once`` round; 2 when the working directory is
        invalid or no token is configured. Without ``--once`` the function
        loops until interrupted.
    """
    # If python-dotenv is installed, load .env from the current directory.
    # Must happen before the defaults below are read from the environment.
    if load_dotenv is not None:
        load_dotenv()

    default_endpoint = os.environ.get("UPLOAD_ENDPOINT", "").strip() or DEFAULT_ENDPOINT
    default_db = os.environ.get("UPLOAD_DB", "").strip() or DEFAULT_DB_NAME

    parser = argparse.ArgumentParser(description="Upload local JSON files to target site")
    parser.add_argument("--dir", default=".", help="工作目录(默认当前目录)")
    parser.add_argument("--endpoint", default=default_endpoint, help="上传接口 URL")
    parser.add_argument(
        "--token",
        default=None,
        help=f"上传鉴权 token(也可用环境变量 {ENV_TOKEN} 配置)",
    )
    parser.add_argument(
        "--interval",
        type=int,
        default=120,
        help="扫描间隔秒数(默认 120 秒)",
    )
    parser.add_argument(
        "--db",
        default=default_db,
        help="本地状态数据库文件名(默认 upload_state.sqlite3)",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=60,
        help="单次上传超时秒数(默认 60)",
    )
    parser.add_argument(
        "--verify-tls",
        action="store_true",
        help="开启 TLS 证书校验(仅 https 有意义;默认关闭)",
    )
    parser.add_argument(
        "--once",
        action="store_true",
        help="只扫描并上传一轮后退出(默认持续运行)",
    )
    parser.add_argument(
        "--extra-headers-json",
        default=None,
        help='附加请求头(JSON对象),如 {"Origin":"...","Referer":"..."}',
    )

    args = parser.parse_args()

    logger = setup_logger()

    base_dir = Path(args.dir).expanduser().resolve()
    if not base_dir.exists() or not base_dir.is_dir():
        logger.error(f"目录不存在或不是文件夹: {base_dir}")
        return 2

    # Sub-directories come from the environment; if none are configured,
    # scan base_dir itself.
    subdirs = parse_subdirs_from_env(base_dir)
    scan_dirs = subdirs if subdirs else [base_dir]

    # Token: command line first, then environment / .env.
    # Validated before the database is opened so a config error leaves no open connection.
    token = normalize_token(args.token or os.environ.get(ENV_TOKEN, ""))
    if not token:
        logger.error(f"未配置 token:请使用 --token 或在 .env 设置 {ENV_TOKEN}")
        return 2

    endpoints = parse_endpoints_from_env(args.endpoint)

    db_path = Path(args.db).expanduser().resolve()
    conn = init_db(db_path)

    try:
        logger.info(f"工作目录: {base_dir}")
        if subdirs:
            logger.info(
                f"扫描子目录: {', '.join(str(d) for d in scan_dirs)}(来自环境变量 {ENV_JSON_SUBDIRS}/{ENV_JSON_SUBDIR})"
            )
        else:
            logger.info("未设置子目录环境变量:将扫描工作目录本身")
        logger.info(f"上传站点数: {len(endpoints)}")
        for i, ep in enumerate(endpoints, 1):
            logger.info(f"站点[{i}]: {ep}")
        logger.info(f"状态库: {db_path}")
        logger.info(f"扫描间隔: {args.interval} 秒")

        # Never poll more often than every 5 seconds.
        wait_s = max(5, int(args.interval))

        while True:
            any_new = False

            # Upload in ascending mtime order (closest to "creation order").
            for path in _list_files_by_mtime(scan_dirs):
                if _process_file(conn, logger, path, endpoints, token, args):
                    any_new = True

            if args.once:
                logger.info("已完成(单次模式)。")
                return 0

            _wait_for_next_scan(wait_s, show_countdown=not any_new)
    finally:
        # Runs on normal exit, on errors and on Ctrl+C alike.
        conn.close()
|
||
|
||
|
||
# Run as a script: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|