feat: Phase 3 - 工具层 + 测试 + 数据库迁移
This commit is contained in:
@@ -1,3 +1,81 @@
|
||||
"""
|
||||
工具模块
|
||||
"""
|
||||
from .validators import (
|
||||
UserRegistrationSchema,
|
||||
UserLoginSchema,
|
||||
SessionCreateSchema,
|
||||
MessageSendSchema,
|
||||
AgentRegistrationSchema,
|
||||
GatewayRegistrationSchema,
|
||||
validate_uuid,
|
||||
validate_email,
|
||||
validate_username,
|
||||
validate_url,
|
||||
sanitize_string,
|
||||
ValidationUtils,
|
||||
)
|
||||
from .security import (
|
||||
generate_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
hash_token,
|
||||
verify_token_hash,
|
||||
generate_api_key,
|
||||
secure_compare,
|
||||
mask_sensitive_data,
|
||||
RateLimiter,
|
||||
IPWhitelist,
|
||||
)
|
||||
from .helpers import (
|
||||
format_datetime,
|
||||
parse_datetime,
|
||||
format_duration,
|
||||
truncate_string,
|
||||
safe_json_loads,
|
||||
safe_json_dumps,
|
||||
generate_session_title,
|
||||
calculate_timeout,
|
||||
merge_dicts,
|
||||
filter_none_values,
|
||||
PaginationHelper,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Validators
|
||||
'UserRegistrationSchema',
|
||||
'UserLoginSchema',
|
||||
'SessionCreateSchema',
|
||||
'MessageSendSchema',
|
||||
'AgentRegistrationSchema',
|
||||
'GatewayRegistrationSchema',
|
||||
'validate_uuid',
|
||||
'validate_email',
|
||||
'validate_username',
|
||||
'validate_url',
|
||||
'sanitize_string',
|
||||
'ValidationUtils',
|
||||
# Security
|
||||
'generate_token',
|
||||
'hash_password',
|
||||
'verify_password',
|
||||
'hash_token',
|
||||
'verify_token_hash',
|
||||
'generate_api_key',
|
||||
'secure_compare',
|
||||
'mask_sensitive_data',
|
||||
'RateLimiter',
|
||||
'IPWhitelist',
|
||||
# Helpers
|
||||
'format_datetime',
|
||||
'parse_datetime',
|
||||
'format_duration',
|
||||
'truncate_string',
|
||||
'safe_json_loads',
|
||||
'safe_json_dumps',
|
||||
'generate_session_title',
|
||||
'calculate_timeout',
|
||||
'merge_dicts',
|
||||
'filter_none_values',
|
||||
'PaginationHelper',
|
||||
]
|
||||
|
||||
113
app/utils/helpers.py
Normal file
113
app/utils/helpers.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
辅助函数
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Any
|
||||
import json
|
||||
|
||||
|
||||
def format_datetime(dt: Optional[datetime]) -> Optional[str]:
|
||||
"""格式化日期时间"""
|
||||
if not dt:
|
||||
return None
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
def parse_datetime(dt_str: str) -> Optional[datetime]:
|
||||
"""解析日期时间字符串"""
|
||||
if not dt_str:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(dt_str)
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def format_duration(seconds: int) -> str:
|
||||
"""格式化时长"""
|
||||
if seconds < 60:
|
||||
return f"{seconds}s"
|
||||
elif seconds < 3600:
|
||||
return f"{seconds // 60}m"
|
||||
elif seconds < 86400:
|
||||
return f"{seconds // 3600}h"
|
||||
else:
|
||||
return f"{seconds // 86400}d"
|
||||
|
||||
|
||||
def truncate_string(s: str, max_length: int = 100, suffix: str = "...") -> str:
|
||||
"""截断字符串"""
|
||||
if len(s) <= max_length:
|
||||
return s
|
||||
return s[:max_length - len(suffix)] + suffix
|
||||
|
||||
|
||||
def safe_json_loads(data: str, default: Any = None) -> Any:
|
||||
"""安全的 JSON 解析"""
|
||||
try:
|
||||
return json.loads(data)
|
||||
except:
|
||||
return default
|
||||
|
||||
|
||||
def safe_json_dumps(data: Any, default: str = "{}") -> str:
|
||||
"""安全的 JSON 序列化"""
|
||||
try:
|
||||
return json.dumps(data, ensure_ascii=False)
|
||||
except:
|
||||
return default
|
||||
|
||||
|
||||
def generate_session_title() -> str:
|
||||
"""生成默认会话标题"""
|
||||
return f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
|
||||
|
||||
|
||||
def calculate_timeout(start_time: datetime, timeout_seconds: int) -> bool:
|
||||
"""检查是否超时"""
|
||||
if not start_time:
|
||||
return True
|
||||
|
||||
elapsed = (datetime.utcnow() - start_time).total_seconds()
|
||||
return elapsed > timeout_seconds
|
||||
|
||||
|
||||
def merge_dicts(base: dict, override: dict) -> dict:
|
||||
"""合并字典"""
|
||||
result = base.copy()
|
||||
result.update(override)
|
||||
return result
|
||||
|
||||
|
||||
def filter_none_values(data: dict) -> dict:
|
||||
"""过滤字典中的 None 值"""
|
||||
return {k: v for k, v in data.items() if v is not None}
|
||||
|
||||
|
||||
class PaginationHelper:
|
||||
"""分页辅助类"""
|
||||
|
||||
@staticmethod
|
||||
def paginate(query, page: int = 1, per_page: int = 20):
|
||||
"""分页查询"""
|
||||
if page < 1:
|
||||
page = 1
|
||||
if per_page < 1:
|
||||
per_page = 20
|
||||
if per_page > 100:
|
||||
per_page = 100
|
||||
|
||||
total = query.count()
|
||||
items = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
pages = (total + per_page - 1) // per_page
|
||||
|
||||
return {
|
||||
'items': items,
|
||||
'total': total,
|
||||
'page': page,
|
||||
'per_page': per_page,
|
||||
'pages': pages,
|
||||
'has_next': page < pages,
|
||||
'has_prev': page > 1,
|
||||
}
|
||||
121
app/utils/security.py
Normal file
121
app/utils/security.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
安全工具
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
import bcrypt
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def generate_token(length: int = 32) -> str:
|
||||
"""生成随机 Token"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""密码哈希"""
|
||||
return bcrypt.hashpw(
|
||||
password.encode('utf-8'),
|
||||
bcrypt.gensalt()
|
||||
).decode('utf-8')
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""验证密码"""
|
||||
return bcrypt.checkpw(
|
||||
password.encode('utf-8'),
|
||||
password_hash.encode('utf-8')
|
||||
)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Token 哈希(用于存储)"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def verify_token_hash(token: str, token_hash: str) -> bool:
|
||||
"""验证 Token 哈希"""
|
||||
return hash_token(token) == token_hash
|
||||
|
||||
|
||||
def generate_api_key() -> str:
|
||||
"""生成 API Key"""
|
||||
return f"pit_{secrets.token_urlsafe(32)}"
|
||||
|
||||
|
||||
def secure_compare(val1: str, val2: str) -> bool:
|
||||
"""安全字符串比较(防时序攻击)"""
|
||||
return hmac.compare_digest(val1.encode(), val2.encode())
|
||||
|
||||
|
||||
def mask_sensitive_data(data: str, visible_chars: int = 4) -> str:
|
||||
"""脱敏处理"""
|
||||
if len(data) <= visible_chars * 2:
|
||||
return '*' * len(data)
|
||||
|
||||
return data[:visible_chars] + '*' * (len(data) - visible_chars * 2) + data[-visible_chars:]
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""简单内存限流器"""
|
||||
|
||||
def __init__(self):
|
||||
self._storage = {}
|
||||
|
||||
def is_allowed(self, key: str, limit: int = 100, window: int = 60) -> bool:
|
||||
"""检查是否允许请求"""
|
||||
from time import time
|
||||
|
||||
now = time()
|
||||
window_start = now - window
|
||||
|
||||
# 获取该 key 的请求记录
|
||||
requests = self._storage.get(key, [])
|
||||
|
||||
# 清理过期记录
|
||||
requests = [t for t in requests if t > window_start]
|
||||
|
||||
# 检查是否超过限制
|
||||
if len(requests) >= limit:
|
||||
self._storage[key] = requests
|
||||
return False
|
||||
|
||||
# 记录新请求
|
||||
requests.append(now)
|
||||
self._storage[key] = requests
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class IPWhitelist:
|
||||
"""IP 白名单"""
|
||||
|
||||
def __init__(self, allowed_ips: list = None):
|
||||
self.allowed_ips = set(allowed_ips or [])
|
||||
self.allow_all = '*' in self.allowed_ips
|
||||
|
||||
def is_allowed(self, ip: str) -> bool:
|
||||
"""检查 IP 是否允许"""
|
||||
if self.allow_all:
|
||||
return True
|
||||
|
||||
# 支持 CIDR 格式
|
||||
for allowed in self.allowed_ips:
|
||||
if '/' in allowed:
|
||||
if self._ip_in_cidr(ip, allowed):
|
||||
return True
|
||||
elif ip == allowed:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _ip_in_cidr(self, ip: str, cidr: str) -> bool:
|
||||
"""检查 IP 是否在 CIDR 范围内"""
|
||||
try:
|
||||
import ipaddress
|
||||
network = ipaddress.ip_network(cidr, strict=False)
|
||||
address = ipaddress.ip_address(ip)
|
||||
return address in network
|
||||
except:
|
||||
return False
|
||||
120
app/utils/validators.py
Normal file
120
app/utils/validators.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
输入验证工具
|
||||
"""
|
||||
from typing import Optional
|
||||
import re
|
||||
from marshmallow import Schema, fields, validate, ValidationError
|
||||
|
||||
|
||||
class UserRegistrationSchema(Schema):
|
||||
"""用户注册验证"""
|
||||
username = fields.Str(required=True, validate=validate.Length(min=3, max=80))
|
||||
email = fields.Email(required=True)
|
||||
password = fields.Str(required=True, validate=validate.Length(min=6, max=128))
|
||||
nickname = fields.Str(validate=validate.Length(max=80), load_default=None)
|
||||
|
||||
|
||||
class UserLoginSchema(Schema):
|
||||
"""用户登录验证"""
|
||||
username = fields.Str(required=True)
|
||||
password = fields.Str(required=True)
|
||||
|
||||
|
||||
class SessionCreateSchema(Schema):
|
||||
"""会话创建验证"""
|
||||
title = fields.Str(validate=validate.Length(max=200), load_default=None)
|
||||
agent_id = fields.Str(validate=validate.Length(max=36), load_default=None)
|
||||
priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
|
||||
|
||||
|
||||
class MessageSendSchema(Schema):
|
||||
"""消息发送验证"""
|
||||
session_id = fields.Str(required=True, validate=validate.Length(max=36))
|
||||
content = fields.Str(required=True, validate=validate.Length(min=1, max=10000))
|
||||
message_type = fields.Str(validate=validate.OneOf(['text', 'media', 'system']), load_default='text')
|
||||
reply_to = fields.Str(validate=validate.Length(max=36), load_default=None)
|
||||
|
||||
|
||||
class AgentRegistrationSchema(Schema):
|
||||
"""Agent 注册验证"""
|
||||
name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
|
||||
gateway_id = fields.Str(validate=validate.Length(max=36), load_default=None)
|
||||
model = fields.Str(validate=validate.Length(max=80), load_default=None)
|
||||
capabilities = fields.List(fields.Str(), load_default=[])
|
||||
priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
|
||||
weight = fields.Int(validate=validate.Range(min=1, max=100), load_default=10)
|
||||
|
||||
|
||||
class GatewayRegistrationSchema(Schema):
|
||||
"""Gateway 注册验证"""
|
||||
name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
|
||||
url = fields.Str(required=True, validate=validate.Length(max=256))
|
||||
token = fields.Str(validate=validate.Length(max=256), load_default=None)
|
||||
|
||||
|
||||
def validate_uuid(uuid_str: str) -> bool:
|
||||
"""验证 UUID 格式"""
|
||||
pattern = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$'
|
||||
return bool(re.match(pattern, uuid_str.lower()))
|
||||
|
||||
|
||||
def validate_email(email: str) -> bool:
|
||||
"""验证邮箱格式"""
|
||||
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
||||
return bool(re.match(pattern, email))
|
||||
|
||||
|
||||
def validate_username(username: str) -> bool:
|
||||
"""验证用户名格式"""
|
||||
# 字母、数字、下划线,3-80 字符
|
||||
pattern = r'^[a-zA-Z0-9_]{3,80}$'
|
||||
return bool(re.match(pattern, username))
|
||||
|
||||
|
||||
def validate_url(url: str) -> bool:
|
||||
"""验证 URL 格式"""
|
||||
pattern = r'^https?://[^\s/$.?#].[^\s]*$'
|
||||
return bool(re.match(pattern, url, re.IGNORECASE))
|
||||
|
||||
|
||||
def sanitize_string(input_str: str, max_length: int = 1000) -> str:
|
||||
"""清理字符串输入"""
|
||||
if not input_str:
|
||||
return ""
|
||||
|
||||
# 去除首尾空格
|
||||
cleaned = input_str.strip()
|
||||
|
||||
# 限制长度
|
||||
if len(cleaned) > max_length:
|
||||
cleaned = cleaned[:max_length]
|
||||
|
||||
# 去除危险字符
|
||||
cleaned = cleaned.replace('\x00', '')
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
class ValidationUtils:
|
||||
"""验证工具类"""
|
||||
|
||||
@staticmethod
|
||||
def validate_json(data: dict, schema: Schema) -> tuple:
|
||||
"""验证 JSON 数据"""
|
||||
try:
|
||||
result = schema.load(data)
|
||||
return True, result, None
|
||||
except ValidationError as e:
|
||||
return False, None, e.messages
|
||||
|
||||
@staticmethod
|
||||
def validate_pagination(page: int, per_page: int, max_per_page: int = 100) -> tuple:
|
||||
"""验证分页参数"""
|
||||
if page < 1:
|
||||
page = 1
|
||||
if per_page < 1:
|
||||
per_page = 20
|
||||
if per_page > max_per_page:
|
||||
per_page = max_per_page
|
||||
|
||||
return page, per_page
|
||||
Reference in New Issue
Block a user