feat(*): 添加测试项目代码
This commit is contained in:
0
handlers/__init__.py
Normal file
0
handlers/__init__.py
Normal file
9
handlers/health.py
Normal file
9
handlers/health.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""健康检查端点。"""
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from response import ok
|
||||
|
||||
|
||||
async def health_handler(request: web.Request) -> web.Response:
|
||||
return ok(data={"health": "ok"})
|
||||
125
handlers/message.py
Normal file
125
handlers/message.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""消息发送处理器:JSON 解析、参数校验、QQ API 超时与重试。"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from config import QQ_API_MAX_RETRIES, QQ_API_TIMEOUT
|
||||
from response import error, ok
|
||||
|
||||
VALID_MSG_TYPES = {"text", "image", "file", "video"}
|
||||
|
||||
# 每种消息类型必填的字段
|
||||
REQUIRED_FIELDS: dict[str, list[str]] = {
|
||||
"text": ["msg"],
|
||||
"image": ["url"],
|
||||
"file": ["url"],
|
||||
"video": ["url"],
|
||||
}
|
||||
|
||||
|
||||
def _validate_payload(data: dict) -> tuple[dict | None, web.Response | None]:
|
||||
"""校验请求体,返回 (data, None) 或 (None, error_response)。"""
|
||||
group_id = data.get("group_id")
|
||||
user_id = data.get("user_id")
|
||||
|
||||
if not group_id and not user_id:
|
||||
return None, error("need group_id or user_id")
|
||||
|
||||
msg_type = data.get("type", "text")
|
||||
if msg_type not in VALID_MSG_TYPES:
|
||||
return None, error(f"invalid type: {msg_type}, must be one of {VALID_MSG_TYPES}")
|
||||
|
||||
# 检查必填字段
|
||||
missing = [f for f in REQUIRED_FIELDS.get(msg_type, []) if not data.get(f)]
|
||||
if missing:
|
||||
return None, error(f"missing required fields: {', '.join(missing)}")
|
||||
|
||||
return data, None
|
||||
|
||||
|
||||
async def _call_qq_api(coro_factory, request: web.Request) -> web.Response:
|
||||
"""带超时和重试的 QQ API 调用。"""
|
||||
logger = request.app["logger"]
|
||||
rid = request.get("request_id", "-")
|
||||
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(1, QQ_API_MAX_RETRIES + 1):
|
||||
try:
|
||||
await asyncio.wait_for(coro_factory(), timeout=QQ_API_TIMEOUT)
|
||||
return ok()
|
||||
except asyncio.TimeoutError:
|
||||
last_exc = asyncio.TimeoutError()
|
||||
logger.warning(f"[{rid}] QQ API timeout, attempt {attempt}/{QQ_API_MAX_RETRIES}")
|
||||
except Exception as exc:
|
||||
last_exc = exc
|
||||
logger.error(f"[{rid}] QQ API error: {exc}, attempt {attempt}/{QQ_API_MAX_RETRIES}")
|
||||
|
||||
logger.error(f"[{rid}] QQ API failed after {QQ_API_MAX_RETRIES} retries: {last_exc}")
|
||||
return error(f"qq api failed: {last_exc}", code=502, status=502)
|
||||
|
||||
|
||||
async def webhook_handler(request: web.Request) -> web.Response:
|
||||
"""处理消息发送请求。"""
|
||||
# 安全解析 JSON(aiohttp 可能抛 JSONDecodeError 或 ContentTypeError)
|
||||
try:
|
||||
data = await request.json()
|
||||
except Exception:
|
||||
return error("invalid json")
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return error("request body must be a json object")
|
||||
|
||||
data, err = _validate_payload(data)
|
||||
if err:
|
||||
return err
|
||||
|
||||
msg_type = data.get("type", "text")
|
||||
group_id = data.get("group_id")
|
||||
user_id = data.get("user_id")
|
||||
msg = data.get("msg", "")
|
||||
url = data.get("url", "")
|
||||
|
||||
# 获取 ncatbot API 实例
|
||||
api = request.app["qq_api"]
|
||||
|
||||
if group_id:
|
||||
match msg_type:
|
||||
case "text":
|
||||
return await _call_qq_api(
|
||||
lambda: api.qq.send_group_text(group_id=group_id, text=msg), request
|
||||
)
|
||||
case "image":
|
||||
return await _call_qq_api(
|
||||
lambda: api.qq.send_group_image(group_id=group_id, image=url), request
|
||||
)
|
||||
case "file":
|
||||
return await _call_qq_api(
|
||||
lambda: api.qq.send_group_file(group_id=group_id, file=url, name=url.split("/")[-1]),
|
||||
request,
|
||||
)
|
||||
case "video":
|
||||
return await _call_qq_api(
|
||||
lambda: api.qq.send_group_video(group_id=group_id, video=url), request
|
||||
)
|
||||
else:
|
||||
match msg_type:
|
||||
case "text":
|
||||
return await _call_qq_api(
|
||||
lambda: api.qq.send_private_text(user_id=user_id, text=msg), request
|
||||
)
|
||||
case "image":
|
||||
return await _call_qq_api(
|
||||
lambda: api.qq.send_private_image(user_id=user_id, image=url), request
|
||||
)
|
||||
case "file":
|
||||
return await _call_qq_api(
|
||||
lambda: api.qq.send_private_file(user_id=user_id, file=url, name=url.split("/")[-1]),
|
||||
request,
|
||||
)
|
||||
case "video":
|
||||
return await _call_qq_api(
|
||||
lambda: api.qq.send_private_video(user_id=user_id, video=url), request
|
||||
)
|
||||
|
||||
return error("unreachable", code=500, status=500)
|
||||
102
handlers/upload.py
Normal file
102
handlers/upload.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""上传处理器:文件上传、大小/类型限制、异步写入、自动清理过期文件。"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from aiohttp import BodyPartReader, web
|
||||
|
||||
from config import ALLOWED_EXTENSIONS, MAX_UPLOAD_SIZE, UPLOAD_DIR
|
||||
from response import error, ok
|
||||
|
||||
logger = logging.getLogger("webhook-plugin.upload")
|
||||
|
||||
# 文件最大保留秒数(默认 24 小时)
|
||||
FILE_TTL_SECONDS: int = 24 * 60 * 60
|
||||
|
||||
|
||||
def _check_extension(filename: str) -> bool:
|
||||
"""检查文件扩展名是否在允许列表内。"""
|
||||
if not ALLOWED_EXTENSIONS:
|
||||
return True
|
||||
ext = Path(filename).suffix.lstrip(".").lower()
|
||||
return ext in ALLOWED_EXTENSIONS
|
||||
|
||||
|
||||
async def cleanup_expired_files() -> None:
|
||||
"""删除上传目录中超过 FILE_TTL_SECONDS 的文件。"""
|
||||
if not UPLOAD_DIR.exists():
|
||||
return
|
||||
now = time.time()
|
||||
for path in UPLOAD_DIR.iterdir():
|
||||
if path.is_file() and (now - path.stat().st_mtime) > FILE_TTL_SECONDS:
|
||||
try:
|
||||
path.unlink()
|
||||
logger.info("已清理过期文件: %s", path.name)
|
||||
except OSError as exc:
|
||||
logger.warning("清理文件失败 %s: %s", path.name, exc)
|
||||
|
||||
|
||||
async def upload_handler(request: web.Request) -> web.Response:
|
||||
"""接收 multipart/form-data 上传,保存到 uploads 目录,返回相对文件 ID。"""
|
||||
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
reader = await request.multipart()
|
||||
saved_ids: list[str] = []
|
||||
|
||||
async for part in reader:
|
||||
if not isinstance(part, BodyPartReader) or not part.filename:
|
||||
continue
|
||||
|
||||
filename: str = Path(part.filename).name # 防路径穿越
|
||||
|
||||
if not _check_extension(filename):
|
||||
return error(f"file type not allowed: {filename}", code=415)
|
||||
|
||||
# 读取文件内容并检查大小
|
||||
chunks: list[bytes] = []
|
||||
total_size = 0
|
||||
while True:
|
||||
chunk = await part.read_chunk(65536)
|
||||
if not chunk:
|
||||
break
|
||||
total_size += len(chunk)
|
||||
if total_size > MAX_UPLOAD_SIZE:
|
||||
return error(
|
||||
f"file too large, max {MAX_UPLOAD_SIZE // (1024*1024)} MB",
|
||||
code=413,
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
if total_size == 0:
|
||||
return error("empty file", code=400)
|
||||
|
||||
# 同名文件自动重命名
|
||||
save_path = UPLOAD_DIR / filename
|
||||
stem = Path(filename).stem
|
||||
suffix = Path(filename).suffix
|
||||
counter = 1
|
||||
while save_path.exists():
|
||||
save_path = UPLOAD_DIR / f"{stem}_{counter}{suffix}"
|
||||
counter += 1
|
||||
|
||||
# 使用线程池写文件,避免阻塞事件循环
|
||||
data = b"".join(chunks)
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(None, _write_file, save_path, data)
|
||||
|
||||
# 返回相对路径作为文件 ID
|
||||
file_id = save_path.relative_to(UPLOAD_DIR).as_posix()
|
||||
saved_ids.append(file_id)
|
||||
|
||||
if not saved_ids:
|
||||
return error("no file uploaded", code=400)
|
||||
|
||||
return ok(data={"files": saved_ids, "path": saved_ids[0] if len(saved_ids) == 1 else None})
|
||||
|
||||
|
||||
def _write_file(path: Path, data: bytes) -> None:
|
||||
"""同步写文件,由线程池调用。"""
|
||||
with open(path, "wb") as f:
|
||||
f.write(data)
|
||||
Reference in New Issue
Block a user