feat: Phase 3 - 工具层 + 测试 + 数据库迁移

This commit is contained in:
2026-03-14 20:14:59 +08:00
parent 1836d118fe
commit 6bafd21e02
14 changed files with 1191 additions and 0 deletions

View File

@@ -1,3 +1,81 @@
"""
工具模块
"""
from .validators import (
UserRegistrationSchema,
UserLoginSchema,
SessionCreateSchema,
MessageSendSchema,
AgentRegistrationSchema,
GatewayRegistrationSchema,
validate_uuid,
validate_email,
validate_username,
validate_url,
sanitize_string,
ValidationUtils,
)
from .security import (
generate_token,
hash_password,
verify_password,
hash_token,
verify_token_hash,
generate_api_key,
secure_compare,
mask_sensitive_data,
RateLimiter,
IPWhitelist,
)
from .helpers import (
format_datetime,
parse_datetime,
format_duration,
truncate_string,
safe_json_loads,
safe_json_dumps,
generate_session_title,
calculate_timeout,
merge_dicts,
filter_none_values,
PaginationHelper,
)
__all__ = [
# Validators
'UserRegistrationSchema',
'UserLoginSchema',
'SessionCreateSchema',
'MessageSendSchema',
'AgentRegistrationSchema',
'GatewayRegistrationSchema',
'validate_uuid',
'validate_email',
'validate_username',
'validate_url',
'sanitize_string',
'ValidationUtils',
# Security
'generate_token',
'hash_password',
'verify_password',
'hash_token',
'verify_token_hash',
'generate_api_key',
'secure_compare',
'mask_sensitive_data',
'RateLimiter',
'IPWhitelist',
# Helpers
'format_datetime',
'parse_datetime',
'format_duration',
'truncate_string',
'safe_json_loads',
'safe_json_dumps',
'generate_session_title',
'calculate_timeout',
'merge_dicts',
'filter_none_values',
'PaginationHelper',
]

113
app/utils/helpers.py Normal file
View File

@@ -0,0 +1,113 @@
"""
辅助函数
"""
from datetime import datetime, timedelta
from typing import Optional, Any
import json
def format_datetime(dt: Optional[datetime]) -> Optional[str]:
"""格式化日期时间"""
if not dt:
return None
return dt.isoformat()
def parse_datetime(dt_str: str) -> Optional[datetime]:
"""解析日期时间字符串"""
if not dt_str:
return None
try:
return datetime.fromisoformat(dt_str)
except:
return None
def format_duration(seconds: int) -> str:
"""格式化时长"""
if seconds < 60:
return f"{seconds}s"
elif seconds < 3600:
return f"{seconds // 60}m"
elif seconds < 86400:
return f"{seconds // 3600}h"
else:
return f"{seconds // 86400}d"
def truncate_string(s: str, max_length: int = 100, suffix: str = "...") -> str:
"""截断字符串"""
if len(s) <= max_length:
return s
return s[:max_length - len(suffix)] + suffix
def safe_json_loads(data: str, default: Any = None) -> Any:
"""安全的 JSON 解析"""
try:
return json.loads(data)
except:
return default
def safe_json_dumps(data: Any, default: str = "{}") -> str:
"""安全的 JSON 序列化"""
try:
return json.dumps(data, ensure_ascii=False)
except:
return default
def generate_session_title() -> str:
"""生成默认会话标题"""
return f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
def calculate_timeout(start_time: datetime, timeout_seconds: int) -> bool:
"""检查是否超时"""
if not start_time:
return True
elapsed = (datetime.utcnow() - start_time).total_seconds()
return elapsed > timeout_seconds
def merge_dicts(base: dict, override: dict) -> dict:
"""合并字典"""
result = base.copy()
result.update(override)
return result
def filter_none_values(data: dict) -> dict:
"""过滤字典中的 None 值"""
return {k: v for k, v in data.items() if v is not None}
class PaginationHelper:
"""分页辅助类"""
@staticmethod
def paginate(query, page: int = 1, per_page: int = 20):
"""分页查询"""
if page < 1:
page = 1
if per_page < 1:
per_page = 20
if per_page > 100:
per_page = 100
total = query.count()
items = query.offset((page - 1) * per_page).limit(per_page).all()
pages = (total + per_page - 1) // per_page
return {
'items': items,
'total': total,
'page': page,
'per_page': per_page,
'pages': pages,
'has_next': page < pages,
'has_prev': page > 1,
}

121
app/utils/security.py Normal file
View File

@@ -0,0 +1,121 @@
"""
安全工具
"""
import hashlib
import hmac
import secrets
import bcrypt
from typing import Optional
def generate_token(length: int = 32) -> str:
"""生成随机 Token"""
return secrets.token_urlsafe(length)
def hash_password(password: str) -> str:
"""密码哈希"""
return bcrypt.hashpw(
password.encode('utf-8'),
bcrypt.gensalt()
).decode('utf-8')
def verify_password(password: str, password_hash: str) -> bool:
"""验证密码"""
return bcrypt.checkpw(
password.encode('utf-8'),
password_hash.encode('utf-8')
)
def hash_token(token: str) -> str:
"""Token 哈希(用于存储)"""
return hashlib.sha256(token.encode()).hexdigest()
def verify_token_hash(token: str, token_hash: str) -> bool:
"""验证 Token 哈希"""
return hash_token(token) == token_hash
def generate_api_key() -> str:
"""生成 API Key"""
return f"pit_{secrets.token_urlsafe(32)}"
def secure_compare(val1: str, val2: str) -> bool:
"""安全字符串比较(防时序攻击)"""
return hmac.compare_digest(val1.encode(), val2.encode())
def mask_sensitive_data(data: str, visible_chars: int = 4) -> str:
"""脱敏处理"""
if len(data) <= visible_chars * 2:
return '*' * len(data)
return data[:visible_chars] + '*' * (len(data) - visible_chars * 2) + data[-visible_chars:]
class RateLimiter:
"""简单内存限流器"""
def __init__(self):
self._storage = {}
def is_allowed(self, key: str, limit: int = 100, window: int = 60) -> bool:
"""检查是否允许请求"""
from time import time
now = time()
window_start = now - window
# 获取该 key 的请求记录
requests = self._storage.get(key, [])
# 清理过期记录
requests = [t for t in requests if t > window_start]
# 检查是否超过限制
if len(requests) >= limit:
self._storage[key] = requests
return False
# 记录新请求
requests.append(now)
self._storage[key] = requests
return True
class IPWhitelist:
"""IP 白名单"""
def __init__(self, allowed_ips: list = None):
self.allowed_ips = set(allowed_ips or [])
self.allow_all = '*' in self.allowed_ips
def is_allowed(self, ip: str) -> bool:
"""检查 IP 是否允许"""
if self.allow_all:
return True
# 支持 CIDR 格式
for allowed in self.allowed_ips:
if '/' in allowed:
if self._ip_in_cidr(ip, allowed):
return True
elif ip == allowed:
return True
return False
def _ip_in_cidr(self, ip: str, cidr: str) -> bool:
"""检查 IP 是否在 CIDR 范围内"""
try:
import ipaddress
network = ipaddress.ip_network(cidr, strict=False)
address = ipaddress.ip_address(ip)
return address in network
except:
return False

120
app/utils/validators.py Normal file
View File

@@ -0,0 +1,120 @@
"""
输入验证工具
"""
from typing import Optional
import re
from marshmallow import Schema, fields, validate, ValidationError
class UserRegistrationSchema(Schema):
"""用户注册验证"""
username = fields.Str(required=True, validate=validate.Length(min=3, max=80))
email = fields.Email(required=True)
password = fields.Str(required=True, validate=validate.Length(min=6, max=128))
nickname = fields.Str(validate=validate.Length(max=80), load_default=None)
class UserLoginSchema(Schema):
"""用户登录验证"""
username = fields.Str(required=True)
password = fields.Str(required=True)
class SessionCreateSchema(Schema):
"""会话创建验证"""
title = fields.Str(validate=validate.Length(max=200), load_default=None)
agent_id = fields.Str(validate=validate.Length(max=36), load_default=None)
priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
class MessageSendSchema(Schema):
"""消息发送验证"""
session_id = fields.Str(required=True, validate=validate.Length(max=36))
content = fields.Str(required=True, validate=validate.Length(min=1, max=10000))
message_type = fields.Str(validate=validate.OneOf(['text', 'media', 'system']), load_default='text')
reply_to = fields.Str(validate=validate.Length(max=36), load_default=None)
class AgentRegistrationSchema(Schema):
"""Agent 注册验证"""
name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
gateway_id = fields.Str(validate=validate.Length(max=36), load_default=None)
model = fields.Str(validate=validate.Length(max=80), load_default=None)
capabilities = fields.List(fields.Str(), load_default=[])
priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
weight = fields.Int(validate=validate.Range(min=1, max=100), load_default=10)
class GatewayRegistrationSchema(Schema):
"""Gateway 注册验证"""
name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
url = fields.Str(required=True, validate=validate.Length(max=256))
token = fields.Str(validate=validate.Length(max=256), load_default=None)
def validate_uuid(uuid_str: str) -> bool:
"""验证 UUID 格式"""
pattern = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$'
return bool(re.match(pattern, uuid_str.lower()))
def validate_email(email: str) -> bool:
"""验证邮箱格式"""
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
return bool(re.match(pattern, email))
def validate_username(username: str) -> bool:
"""验证用户名格式"""
# 字母、数字、下划线3-80 字符
pattern = r'^[a-zA-Z0-9_]{3,80}$'
return bool(re.match(pattern, username))
def validate_url(url: str) -> bool:
"""验证 URL 格式"""
pattern = r'^https?://[^\s/$.?#].[^\s]*$'
return bool(re.match(pattern, url, re.IGNORECASE))
def sanitize_string(input_str: str, max_length: int = 1000) -> str:
"""清理字符串输入"""
if not input_str:
return ""
# 去除首尾空格
cleaned = input_str.strip()
# 限制长度
if len(cleaned) > max_length:
cleaned = cleaned[:max_length]
# 去除危险字符
cleaned = cleaned.replace('\x00', '')
return cleaned
class ValidationUtils:
"""验证工具类"""
@staticmethod
def validate_json(data: dict, schema: Schema) -> tuple:
"""验证 JSON 数据"""
try:
result = schema.load(data)
return True, result, None
except ValidationError as e:
return False, None, e.messages
@staticmethod
def validate_pagination(page: int, per_page: int, max_per_page: int = 100) -> tuple:
"""验证分页参数"""
if page < 1:
page = 1
if per_page < 1:
per_page = 20
if per_page > max_per_page:
per_page = max_per_page
return page, per_page