feat: Phase 3 - 工具层 + 测试 + 数据库迁移
This commit is contained in:
40
CHANGELOG.md
40
CHANGELOG.md
@@ -11,6 +11,46 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
---
|
||||
|
||||
## [0.5.0] - 2026-03-14
|
||||
|
||||
### Added
|
||||
|
||||
#### 🔧 工具层实现 (Phase 3)
|
||||
|
||||
- **validators.py** - 输入验证工具
|
||||
- Marshmallow Schema 验证
|
||||
- UUID/Email/用户名/URL 验证
|
||||
- 字符串清理
|
||||
|
||||
- **security.py** - 安全工具
|
||||
- Token/密码哈希
|
||||
- API Key 生成
|
||||
- RateLimiter 限流器
|
||||
- IPWhitelist IP 白名单
|
||||
|
||||
- **helpers.py** - 辅助函数
|
||||
- 日期时间格式化
|
||||
- 分页辅助
|
||||
- JSON 安全解析
|
||||
|
||||
#### 🧪 测试模块
|
||||
|
||||
- **conftest.py** - pytest 配置
|
||||
- **test_auth.py** - 认证 API 单元测试
|
||||
- **test_scheduler.py** - 调度器单元测试
|
||||
- **test_message_queue.py** - 消息队列单元测试
|
||||
|
||||
#### 🗄️ 数据库迁移
|
||||
|
||||
- **Alembic 配置** - 数据库迁移工具
|
||||
- **初始迁移脚本** - 创建所有数据表
|
||||
|
||||
### Changed
|
||||
|
||||
- **requirements.txt** - 添加 gevent 依赖(用于 gunicorn)
|
||||
|
||||
---
|
||||
|
||||
## [0.4.0] - 2026-03-14
|
||||
|
||||
### Added
|
||||
|
||||
@@ -1,3 +1,81 @@
|
||||
"""
|
||||
工具模块
|
||||
"""
|
||||
from .validators import (
|
||||
UserRegistrationSchema,
|
||||
UserLoginSchema,
|
||||
SessionCreateSchema,
|
||||
MessageSendSchema,
|
||||
AgentRegistrationSchema,
|
||||
GatewayRegistrationSchema,
|
||||
validate_uuid,
|
||||
validate_email,
|
||||
validate_username,
|
||||
validate_url,
|
||||
sanitize_string,
|
||||
ValidationUtils,
|
||||
)
|
||||
from .security import (
|
||||
generate_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
hash_token,
|
||||
verify_token_hash,
|
||||
generate_api_key,
|
||||
secure_compare,
|
||||
mask_sensitive_data,
|
||||
RateLimiter,
|
||||
IPWhitelist,
|
||||
)
|
||||
from .helpers import (
|
||||
format_datetime,
|
||||
parse_datetime,
|
||||
format_duration,
|
||||
truncate_string,
|
||||
safe_json_loads,
|
||||
safe_json_dumps,
|
||||
generate_session_title,
|
||||
calculate_timeout,
|
||||
merge_dicts,
|
||||
filter_none_values,
|
||||
PaginationHelper,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Validators
|
||||
'UserRegistrationSchema',
|
||||
'UserLoginSchema',
|
||||
'SessionCreateSchema',
|
||||
'MessageSendSchema',
|
||||
'AgentRegistrationSchema',
|
||||
'GatewayRegistrationSchema',
|
||||
'validate_uuid',
|
||||
'validate_email',
|
||||
'validate_username',
|
||||
'validate_url',
|
||||
'sanitize_string',
|
||||
'ValidationUtils',
|
||||
# Security
|
||||
'generate_token',
|
||||
'hash_password',
|
||||
'verify_password',
|
||||
'hash_token',
|
||||
'verify_token_hash',
|
||||
'generate_api_key',
|
||||
'secure_compare',
|
||||
'mask_sensitive_data',
|
||||
'RateLimiter',
|
||||
'IPWhitelist',
|
||||
# Helpers
|
||||
'format_datetime',
|
||||
'parse_datetime',
|
||||
'format_duration',
|
||||
'truncate_string',
|
||||
'safe_json_loads',
|
||||
'safe_json_dumps',
|
||||
'generate_session_title',
|
||||
'calculate_timeout',
|
||||
'merge_dicts',
|
||||
'filter_none_values',
|
||||
'PaginationHelper',
|
||||
]
|
||||
|
||||
113
app/utils/helpers.py
Normal file
113
app/utils/helpers.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
辅助函数
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, Any
|
||||
import json
|
||||
|
||||
|
||||
def format_datetime(dt: Optional[datetime]) -> Optional[str]:
|
||||
"""格式化日期时间"""
|
||||
if not dt:
|
||||
return None
|
||||
return dt.isoformat()
|
||||
|
||||
|
||||
def parse_datetime(dt_str: str) -> Optional[datetime]:
|
||||
"""解析日期时间字符串"""
|
||||
if not dt_str:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromisoformat(dt_str)
|
||||
except:
|
||||
return None
|
||||
|
||||
|
||||
def format_duration(seconds: int) -> str:
|
||||
"""格式化时长"""
|
||||
if seconds < 60:
|
||||
return f"{seconds}s"
|
||||
elif seconds < 3600:
|
||||
return f"{seconds // 60}m"
|
||||
elif seconds < 86400:
|
||||
return f"{seconds // 3600}h"
|
||||
else:
|
||||
return f"{seconds // 86400}d"
|
||||
|
||||
|
||||
def truncate_string(s: str, max_length: int = 100, suffix: str = "...") -> str:
|
||||
"""截断字符串"""
|
||||
if len(s) <= max_length:
|
||||
return s
|
||||
return s[:max_length - len(suffix)] + suffix
|
||||
|
||||
|
||||
def safe_json_loads(data: str, default: Any = None) -> Any:
|
||||
"""安全的 JSON 解析"""
|
||||
try:
|
||||
return json.loads(data)
|
||||
except:
|
||||
return default
|
||||
|
||||
|
||||
def safe_json_dumps(data: Any, default: str = "{}") -> str:
|
||||
"""安全的 JSON 序列化"""
|
||||
try:
|
||||
return json.dumps(data, ensure_ascii=False)
|
||||
except:
|
||||
return default
|
||||
|
||||
|
||||
def generate_session_title() -> str:
|
||||
"""生成默认会话标题"""
|
||||
return f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
|
||||
|
||||
|
||||
def calculate_timeout(start_time: datetime, timeout_seconds: int) -> bool:
|
||||
"""检查是否超时"""
|
||||
if not start_time:
|
||||
return True
|
||||
|
||||
elapsed = (datetime.utcnow() - start_time).total_seconds()
|
||||
return elapsed > timeout_seconds
|
||||
|
||||
|
||||
def merge_dicts(base: dict, override: dict) -> dict:
|
||||
"""合并字典"""
|
||||
result = base.copy()
|
||||
result.update(override)
|
||||
return result
|
||||
|
||||
|
||||
def filter_none_values(data: dict) -> dict:
|
||||
"""过滤字典中的 None 值"""
|
||||
return {k: v for k, v in data.items() if v is not None}
|
||||
|
||||
|
||||
class PaginationHelper:
|
||||
"""分页辅助类"""
|
||||
|
||||
@staticmethod
|
||||
def paginate(query, page: int = 1, per_page: int = 20):
|
||||
"""分页查询"""
|
||||
if page < 1:
|
||||
page = 1
|
||||
if per_page < 1:
|
||||
per_page = 20
|
||||
if per_page > 100:
|
||||
per_page = 100
|
||||
|
||||
total = query.count()
|
||||
items = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||
|
||||
pages = (total + per_page - 1) // per_page
|
||||
|
||||
return {
|
||||
'items': items,
|
||||
'total': total,
|
||||
'page': page,
|
||||
'per_page': per_page,
|
||||
'pages': pages,
|
||||
'has_next': page < pages,
|
||||
'has_prev': page > 1,
|
||||
}
|
||||
121
app/utils/security.py
Normal file
121
app/utils/security.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""
|
||||
安全工具
|
||||
"""
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
import bcrypt
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def generate_token(length: int = 32) -> str:
|
||||
"""生成随机 Token"""
|
||||
return secrets.token_urlsafe(length)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""密码哈希"""
|
||||
return bcrypt.hashpw(
|
||||
password.encode('utf-8'),
|
||||
bcrypt.gensalt()
|
||||
).decode('utf-8')
|
||||
|
||||
|
||||
def verify_password(password: str, password_hash: str) -> bool:
|
||||
"""验证密码"""
|
||||
return bcrypt.checkpw(
|
||||
password.encode('utf-8'),
|
||||
password_hash.encode('utf-8')
|
||||
)
|
||||
|
||||
|
||||
def hash_token(token: str) -> str:
|
||||
"""Token 哈希(用于存储)"""
|
||||
return hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def verify_token_hash(token: str, token_hash: str) -> bool:
|
||||
"""验证 Token 哈希"""
|
||||
return hash_token(token) == token_hash
|
||||
|
||||
|
||||
def generate_api_key() -> str:
|
||||
"""生成 API Key"""
|
||||
return f"pit_{secrets.token_urlsafe(32)}"
|
||||
|
||||
|
||||
def secure_compare(val1: str, val2: str) -> bool:
|
||||
"""安全字符串比较(防时序攻击)"""
|
||||
return hmac.compare_digest(val1.encode(), val2.encode())
|
||||
|
||||
|
||||
def mask_sensitive_data(data: str, visible_chars: int = 4) -> str:
|
||||
"""脱敏处理"""
|
||||
if len(data) <= visible_chars * 2:
|
||||
return '*' * len(data)
|
||||
|
||||
return data[:visible_chars] + '*' * (len(data) - visible_chars * 2) + data[-visible_chars:]
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""简单内存限流器"""
|
||||
|
||||
def __init__(self):
|
||||
self._storage = {}
|
||||
|
||||
def is_allowed(self, key: str, limit: int = 100, window: int = 60) -> bool:
|
||||
"""检查是否允许请求"""
|
||||
from time import time
|
||||
|
||||
now = time()
|
||||
window_start = now - window
|
||||
|
||||
# 获取该 key 的请求记录
|
||||
requests = self._storage.get(key, [])
|
||||
|
||||
# 清理过期记录
|
||||
requests = [t for t in requests if t > window_start]
|
||||
|
||||
# 检查是否超过限制
|
||||
if len(requests) >= limit:
|
||||
self._storage[key] = requests
|
||||
return False
|
||||
|
||||
# 记录新请求
|
||||
requests.append(now)
|
||||
self._storage[key] = requests
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class IPWhitelist:
|
||||
"""IP 白名单"""
|
||||
|
||||
def __init__(self, allowed_ips: list = None):
|
||||
self.allowed_ips = set(allowed_ips or [])
|
||||
self.allow_all = '*' in self.allowed_ips
|
||||
|
||||
def is_allowed(self, ip: str) -> bool:
|
||||
"""检查 IP 是否允许"""
|
||||
if self.allow_all:
|
||||
return True
|
||||
|
||||
# 支持 CIDR 格式
|
||||
for allowed in self.allowed_ips:
|
||||
if '/' in allowed:
|
||||
if self._ip_in_cidr(ip, allowed):
|
||||
return True
|
||||
elif ip == allowed:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _ip_in_cidr(self, ip: str, cidr: str) -> bool:
|
||||
"""检查 IP 是否在 CIDR 范围内"""
|
||||
try:
|
||||
import ipaddress
|
||||
network = ipaddress.ip_network(cidr, strict=False)
|
||||
address = ipaddress.ip_address(ip)
|
||||
return address in network
|
||||
except:
|
||||
return False
|
||||
120
app/utils/validators.py
Normal file
120
app/utils/validators.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
输入验证工具
|
||||
"""
|
||||
from typing import Optional
|
||||
import re
|
||||
from marshmallow import Schema, fields, validate, ValidationError
|
||||
|
||||
|
||||
class UserRegistrationSchema(Schema):
|
||||
"""用户注册验证"""
|
||||
username = fields.Str(required=True, validate=validate.Length(min=3, max=80))
|
||||
email = fields.Email(required=True)
|
||||
password = fields.Str(required=True, validate=validate.Length(min=6, max=128))
|
||||
nickname = fields.Str(validate=validate.Length(max=80), load_default=None)
|
||||
|
||||
|
||||
class UserLoginSchema(Schema):
|
||||
"""用户登录验证"""
|
||||
username = fields.Str(required=True)
|
||||
password = fields.Str(required=True)
|
||||
|
||||
|
||||
class SessionCreateSchema(Schema):
|
||||
"""会话创建验证"""
|
||||
title = fields.Str(validate=validate.Length(max=200), load_default=None)
|
||||
agent_id = fields.Str(validate=validate.Length(max=36), load_default=None)
|
||||
priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
|
||||
|
||||
|
||||
class MessageSendSchema(Schema):
|
||||
"""消息发送验证"""
|
||||
session_id = fields.Str(required=True, validate=validate.Length(max=36))
|
||||
content = fields.Str(required=True, validate=validate.Length(min=1, max=10000))
|
||||
message_type = fields.Str(validate=validate.OneOf(['text', 'media', 'system']), load_default='text')
|
||||
reply_to = fields.Str(validate=validate.Length(max=36), load_default=None)
|
||||
|
||||
|
||||
class AgentRegistrationSchema(Schema):
|
||||
"""Agent 注册验证"""
|
||||
name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
|
||||
gateway_id = fields.Str(validate=validate.Length(max=36), load_default=None)
|
||||
model = fields.Str(validate=validate.Length(max=80), load_default=None)
|
||||
capabilities = fields.List(fields.Str(), load_default=[])
|
||||
priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
|
||||
weight = fields.Int(validate=validate.Range(min=1, max=100), load_default=10)
|
||||
|
||||
|
||||
class GatewayRegistrationSchema(Schema):
|
||||
"""Gateway 注册验证"""
|
||||
name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
|
||||
url = fields.Str(required=True, validate=validate.Length(max=256))
|
||||
token = fields.Str(validate=validate.Length(max=256), load_default=None)
|
||||
|
||||
|
||||
def validate_uuid(uuid_str: str) -> bool:
|
||||
"""验证 UUID 格式"""
|
||||
pattern = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$'
|
||||
return bool(re.match(pattern, uuid_str.lower()))
|
||||
|
||||
|
||||
def validate_email(email: str) -> bool:
|
||||
"""验证邮箱格式"""
|
||||
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
||||
return bool(re.match(pattern, email))
|
||||
|
||||
|
||||
def validate_username(username: str) -> bool:
|
||||
"""验证用户名格式"""
|
||||
# 字母、数字、下划线,3-80 字符
|
||||
pattern = r'^[a-zA-Z0-9_]{3,80}$'
|
||||
return bool(re.match(pattern, username))
|
||||
|
||||
|
||||
def validate_url(url: str) -> bool:
|
||||
"""验证 URL 格式"""
|
||||
pattern = r'^https?://[^\s/$.?#].[^\s]*$'
|
||||
return bool(re.match(pattern, url, re.IGNORECASE))
|
||||
|
||||
|
||||
def sanitize_string(input_str: str, max_length: int = 1000) -> str:
|
||||
"""清理字符串输入"""
|
||||
if not input_str:
|
||||
return ""
|
||||
|
||||
# 去除首尾空格
|
||||
cleaned = input_str.strip()
|
||||
|
||||
# 限制长度
|
||||
if len(cleaned) > max_length:
|
||||
cleaned = cleaned[:max_length]
|
||||
|
||||
# 去除危险字符
|
||||
cleaned = cleaned.replace('\x00', '')
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
class ValidationUtils:
|
||||
"""验证工具类"""
|
||||
|
||||
@staticmethod
|
||||
def validate_json(data: dict, schema: Schema) -> tuple:
|
||||
"""验证 JSON 数据"""
|
||||
try:
|
||||
result = schema.load(data)
|
||||
return True, result, None
|
||||
except ValidationError as e:
|
||||
return False, None, e.messages
|
||||
|
||||
@staticmethod
|
||||
def validate_pagination(page: int, per_page: int, max_per_page: int = 100) -> tuple:
|
||||
"""验证分页参数"""
|
||||
if page < 1:
|
||||
page = 1
|
||||
if per_page < 1:
|
||||
per_page = 20
|
||||
if per_page > max_per_page:
|
||||
per_page = max_per_page
|
||||
|
||||
return page, per_page
|
||||
48
migrations/alembic.ini
Normal file
48
migrations/alembic.ini
Normal file
@@ -0,0 +1,48 @@
|
||||
[alembic]
|
||||
# 脚本位置
|
||||
script_location = migrations
|
||||
|
||||
# 模板文件
|
||||
template_file =
|
||||
|
||||
# 最大保留版本数
|
||||
max_num = 10
|
||||
|
||||
# 数据库 URL(运行时覆盖)
|
||||
sqlalchemy.url = sqlite:///pit_router.db
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
[loggers]
|
||||
keys = root,sqlalchemy,alembic
|
||||
|
||||
[handlers]
|
||||
keys = console
|
||||
|
||||
[formatters]
|
||||
keys = generic
|
||||
|
||||
[logger_root]
|
||||
level = WARN
|
||||
handlers = console
|
||||
qualname =
|
||||
|
||||
[logger_sqlalchemy]
|
||||
level = WARN
|
||||
handlers =
|
||||
qualname = sqlalchemy.engine
|
||||
|
||||
[logger_alembic]
|
||||
level = INFO
|
||||
handlers =
|
||||
qualname = alembic
|
||||
|
||||
[handler_console]
|
||||
class = StreamHandler
|
||||
args = (sys.stderr,)
|
||||
level = NOTSET
|
||||
formatter = generic
|
||||
|
||||
[formatter_generic]
|
||||
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||
datefmt = %H:%M:%S
|
||||
68
migrations/env.py
Normal file
68
migrations/env.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Alembic 环境配置
|
||||
"""
|
||||
from logging.config import fileConfig
|
||||
from sqlalchemy import engine_from_config
|
||||
from sqlalchemy import pool
|
||||
from alembic import context
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 添加项目路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from app.extensions import db
|
||||
from app.config import Config
|
||||
|
||||
# Alembic Config 对象
|
||||
config = context.config
|
||||
|
||||
# 配置日志
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name)
|
||||
|
||||
# 模型元数据
|
||||
target_metadata = db.Model.metadata
|
||||
|
||||
def get_url():
|
||||
"""获取数据库 URL"""
|
||||
return os.getenv('DATABASE_URL', 'sqlite:///pit_router.db')
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""离线迁移"""
|
||||
url = get_url()
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""在线迁移"""
|
||||
configuration = config.get_section(config.config_ini_section)
|
||||
configuration["sqlalchemy.url"] = get_url()
|
||||
connectable = engine_from_config(
|
||||
configuration,
|
||||
prefix="sqlalchemy.",
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
134
migrations/versions/initial.py
Normal file
134
migrations/versions/initial.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
初始化迁移脚本
|
||||
Revision ID: initial
|
||||
Revises:
|
||||
Create Date: 2026-03-14
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers
|
||||
revision = 'initial'
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""创建表"""
|
||||
# 用户表
|
||||
op.create_table(
|
||||
'users',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('username', sa.String(80), unique=True, nullable=False),
|
||||
sa.Column('password_hash', sa.String(256), nullable=False),
|
||||
sa.Column('email', sa.String(120), unique=True, nullable=False),
|
||||
sa.Column('nickname', sa.String(80)),
|
||||
sa.Column('role', sa.String(20), default='user'),
|
||||
sa.Column('status', sa.String(20), default='active'),
|
||||
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
|
||||
sa.Column('last_login_at', sa.DateTime),
|
||||
)
|
||||
|
||||
# Gateway 表
|
||||
op.create_table(
|
||||
'gateways',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('name', sa.String(80), unique=True, nullable=False),
|
||||
sa.Column('url', sa.String(256), nullable=False),
|
||||
sa.Column('token_hash', sa.String(256)),
|
||||
sa.Column('status', sa.String(20), default='offline'),
|
||||
sa.Column('agent_count', sa.Integer, default=0),
|
||||
sa.Column('connection_limit', sa.Integer, default=10),
|
||||
sa.Column('heartbeat_interval', sa.Integer, default=60),
|
||||
sa.Column('allowed_ips', sa.JSON),
|
||||
sa.Column('last_heartbeat', sa.DateTime),
|
||||
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
|
||||
)
|
||||
|
||||
# Agent 表
|
||||
op.create_table(
|
||||
'agents',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('name', sa.String(80), nullable=False),
|
||||
sa.Column('display_name', sa.String(80)),
|
||||
sa.Column('gateway_id', sa.String(36), sa.ForeignKey('gateways.id')),
|
||||
sa.Column('socket_id', sa.String(100)),
|
||||
sa.Column('model', sa.String(80)),
|
||||
sa.Column('capabilities', sa.JSON),
|
||||
sa.Column('status', sa.String(20), default='offline'),
|
||||
sa.Column('priority', sa.Integer, default=5),
|
||||
sa.Column('weight', sa.Integer, default=10),
|
||||
sa.Column('connection_limit', sa.Integer, default=5),
|
||||
sa.Column('current_sessions', sa.Integer, default=0),
|
||||
sa.Column('last_heartbeat', sa.DateTime),
|
||||
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
|
||||
)
|
||||
|
||||
# 会话表
|
||||
op.create_table(
|
||||
'sessions',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
|
||||
sa.Column('primary_agent_id', sa.String(36), sa.ForeignKey('agents.id')),
|
||||
sa.Column('participating_agent_ids', sa.JSON),
|
||||
sa.Column('user_socket_id', sa.String(100)),
|
||||
sa.Column('title', sa.String(200)),
|
||||
sa.Column('channel_type', sa.String(20), default='web'),
|
||||
sa.Column('status', sa.String(20), default='active'),
|
||||
sa.Column('message_count', sa.Integer, default=0),
|
||||
sa.Column('unread_count', sa.Integer, default=0),
|
||||
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
|
||||
sa.Column('updated_at', sa.DateTime),
|
||||
sa.Column('last_active_at', sa.DateTime),
|
||||
)
|
||||
|
||||
# 消息表
|
||||
op.create_table(
|
||||
'messages',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('session_id', sa.String(36), sa.ForeignKey('sessions.id'), nullable=False),
|
||||
sa.Column('sender_type', sa.String(20), nullable=False),
|
||||
sa.Column('sender_id', sa.String(36), nullable=False),
|
||||
sa.Column('message_type', sa.String(20), default='text'),
|
||||
sa.Column('content', sa.Text),
|
||||
sa.Column('content_type', sa.String(20), default='markdown'),
|
||||
sa.Column('reply_to', sa.String(36)),
|
||||
sa.Column('status', sa.String(20), default='sent'),
|
||||
sa.Column('ack_status', sa.String(20), default='pending'),
|
||||
sa.Column('retry_count', sa.Integer, default=0),
|
||||
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
|
||||
sa.Column('delivered_at', sa.DateTime),
|
||||
)
|
||||
|
||||
# 连接表
|
||||
op.create_table(
|
||||
'connections',
|
||||
sa.Column('id', sa.String(36), primary_key=True),
|
||||
sa.Column('socket_id', sa.String(100), unique=True, nullable=False),
|
||||
sa.Column('connection_type', sa.String(20), nullable=False),
|
||||
sa.Column('entity_id', sa.String(36), nullable=False),
|
||||
sa.Column('entity_type', sa.String(20), nullable=False),
|
||||
sa.Column('ip_address', sa.String(45)),
|
||||
sa.Column('user_agent', sa.String(500)),
|
||||
sa.Column('status', sa.String(20), default='connected'),
|
||||
sa.Column('auth_token', sa.String(500)),
|
||||
sa.Column('connected_at', sa.DateTime, default=sa.func.now()),
|
||||
sa.Column('last_activity', sa.DateTime),
|
||||
sa.Column('disconnected_at', sa.DateTime),
|
||||
)
|
||||
|
||||
# 创建索引
|
||||
op.create_index('ix_messages_session_id', 'messages', ['session_id'])
|
||||
op.create_index('ix_sessions_user_id', 'sessions', ['user_id'])
|
||||
op.create_index('ix_sessions_status', 'sessions', ['status'])
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""删除表"""
|
||||
op.drop_table('connections')
|
||||
op.drop_table('messages')
|
||||
op.drop_table('sessions')
|
||||
op.drop_table('agents')
|
||||
op.drop_table('gateways')
|
||||
op.drop_table('users')
|
||||
@@ -7,6 +7,7 @@ Flask-Login==0.6.3
|
||||
Flask-Migrate==4.0.5
|
||||
Flask-SocketIO==5.3.6
|
||||
Flask-SQLAlchemy==3.1.1
|
||||
gevent==24.2.1
|
||||
gunicorn==21.2.0
|
||||
psycopg2-binary==2.9.9
|
||||
PyJWT==2.8.0
|
||||
|
||||
3
tests/__init__.py
Normal file
3
tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
测试模块初始化
|
||||
"""
|
||||
35
tests/conftest.py
Normal file
35
tests/conftest.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
pytest 配置
|
||||
"""
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""创建测试应用"""
|
||||
from app import create_app
|
||||
app = create_app('testing')
|
||||
|
||||
with app.app_context():
|
||||
from app.extensions import db
|
||||
db.create_all()
|
||||
yield app
|
||||
db.session.remove()
|
||||
db.drop_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""创建测试客户端"""
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(app):
|
||||
"""创建测试运行器"""
|
||||
return app.test_cli_runner()
|
||||
158
tests/test_auth.py
Normal file
158
tests/test_auth.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
认证 API 单元测试
|
||||
"""
|
||||
import pytest
|
||||
from app import create_app
|
||||
from app.extensions import db
|
||||
from app.models import User
|
||||
from app.utils.security import hash_password
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""创建测试应用"""
|
||||
app = create_app('testing')
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
yield app
|
||||
db.drop_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""创建测试客户端"""
|
||||
return app.test_client()
|
||||
|
||||
|
||||
class TestAuthAPI:
|
||||
"""认证 API 测试"""
|
||||
|
||||
def test_register_success(self, client):
|
||||
"""测试用户注册成功"""
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'testuser',
|
||||
'email': 'test@example.com',
|
||||
'password': 'password123',
|
||||
'nickname': 'Test User'
|
||||
})
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.get_json()
|
||||
assert data['user']['username'] == 'testuser'
|
||||
assert data['user']['email'] == 'test@example.com'
|
||||
|
||||
def test_register_duplicate_username(self, client, app):
|
||||
"""测试重复用户名"""
|
||||
# 创建已存在的用户
|
||||
with app.app_context():
|
||||
user = User(
|
||||
username='testuser',
|
||||
email='existing@example.com',
|
||||
password_hash=hash_password('password')
|
||||
)
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
# 尝试注册相同用户名
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'testuser',
|
||||
'email': 'new@example.com',
|
||||
'password': 'password123'
|
||||
})
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_register_missing_fields(self, client):
|
||||
"""测试缺少必填字段"""
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'testuser'
|
||||
})
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_login_success(self, client, app):
|
||||
"""测试登录成功"""
|
||||
# 创建用户
|
||||
with app.app_context():
|
||||
user = User(
|
||||
username='testuser',
|
||||
email='test@example.com',
|
||||
password_hash=hash_password('password123')
|
||||
)
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
# 登录
|
||||
response = client.post('/api/auth/login', json={
|
||||
'username': 'testuser',
|
||||
'password': 'password123'
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert 'access_token' in data
|
||||
assert 'refresh_token' in data
|
||||
|
||||
def test_login_invalid_credentials(self, client):
|
||||
"""测试无效凭据"""
|
||||
response = client.post('/api/auth/login', json={
|
||||
'username': 'nonexistent',
|
||||
'password': 'wrong'
|
||||
})
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_verify_token(self, client, app):
|
||||
"""测试 Token 验证"""
|
||||
# 创建用户
|
||||
with app.app_context():
|
||||
user = User(
|
||||
username='testuser',
|
||||
email='test@example.com',
|
||||
password_hash=hash_password('password123')
|
||||
)
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
# 登录获取 token
|
||||
login_response = client.post('/api/auth/login', json={
|
||||
'username': 'testuser',
|
||||
'password': 'password123'
|
||||
})
|
||||
token = login_response.get_json()['access_token']
|
||||
|
||||
# 验证 token
|
||||
response = client.post('/api/auth/verify', headers={
|
||||
'Authorization': f'Bearer {token}'
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert data['valid'] is True
|
||||
|
||||
|
||||
class TestValidation:
|
||||
"""验证测试"""
|
||||
|
||||
def test_validate_email(self, client):
|
||||
"""测试邮箱验证"""
|
||||
from app.utils.validators import validate_email
|
||||
|
||||
assert validate_email('test@example.com') is True
|
||||
assert validate_email('invalid-email') is False
|
||||
|
||||
def test_validate_username(self, client):
|
||||
"""测试用户名验证"""
|
||||
from app.utils.validators import validate_username
|
||||
|
||||
assert validate_username('testuser') is True
|
||||
assert validate_username('ab') is False # 太短
|
||||
assert validate_username('test-user') is True
|
||||
assert validate_username('test@user') is False # 包含非法字符
|
||||
|
||||
def test_validate_uuid(self, client):
|
||||
"""测试 UUID 验证"""
|
||||
from app.utils.validators import validate_uuid
|
||||
|
||||
assert validate_uuid('550e8400-e29b-41d4-a716-446655440000') is True
|
||||
assert validate_uuid('invalid-uuid') is False
|
||||
96
tests/test_message_queue.py
Normal file
96
tests/test_message_queue.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
消息队列单元测试
|
||||
"""
|
||||
import pytest
|
||||
from app.services.message_queue import MessageQueue
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestMessageQueue:
|
||||
"""消息队列测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
"""模拟 Redis 客户端"""
|
||||
with patch('app.services.message_queue.redis_client') as mock:
|
||||
mock.hset.return_value = 1
|
||||
mock.rpush.return_value = 1
|
||||
mock.lpop.return_value = None
|
||||
mock.hgetall.return_value = {}
|
||||
mock.delete.return_value = 1
|
||||
mock.llen.return_value = 0
|
||||
yield mock
|
||||
|
||||
def test_enqueue(self, mock_redis):
|
||||
"""测试消息入队"""
|
||||
message = MagicMock()
|
||||
message.id = 'msg-123'
|
||||
message.session_id = 'session-123'
|
||||
message.sender_type = 'user'
|
||||
message.sender_id = 'user-123'
|
||||
message.content = 'Hello'
|
||||
message.status = 'sent'
|
||||
message.retry_count = 0
|
||||
message.created_at = None
|
||||
|
||||
result = MessageQueue.enqueue(message)
|
||||
|
||||
assert result is True
|
||||
mock_redis.hset.assert_called()
|
||||
mock_redis.rpush.assert_called()
|
||||
|
||||
def test_dequeue_empty(self, mock_redis):
|
||||
"""测试空队列出队"""
|
||||
mock_redis.lpop.return_value = None
|
||||
|
||||
result = MessageQueue.dequeue()
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_dequeue_success(self, mock_redis):
|
||||
"""测试成功出队"""
|
||||
mock_redis.lpop.return_value = 'msg-123'
|
||||
mock_redis.hgetall.return_value = {
|
||||
'id': 'msg-123',
|
||||
'session_id': 'session-123',
|
||||
'content': 'Hello',
|
||||
'retry_count': '0',
|
||||
}
|
||||
|
||||
result = MessageQueue.dequeue()
|
||||
|
||||
assert result is not None
|
||||
assert result['id'] == 'msg-123'
|
||||
|
||||
def test_ack(self, mock_redis):
|
||||
"""测试消息确认"""
|
||||
result = MessageQueue.ack('msg-123')
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called()
|
||||
|
||||
def test_retry_within_limit(self, mock_redis):
|
||||
"""测试重试(未超限)"""
|
||||
mock_redis.hget.return_value = '1'
|
||||
|
||||
result = MessageQueue.retry('msg-123')
|
||||
|
||||
assert result is True
|
||||
mock_redis.hset.assert_called()
|
||||
mock_redis.rpush.assert_called()
|
||||
|
||||
def test_retry_exceed_limit(self, mock_redis):
|
||||
"""测试重试(超限)"""
|
||||
mock_redis.hget.return_value = '3'
|
||||
|
||||
result = MessageQueue.retry('msg-123')
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_get_pending_count(self, mock_redis):
|
||||
"""测试获取待处理数量"""
|
||||
mock_redis.llen.return_value = 5
|
||||
|
||||
count = MessageQueue.get_pending_count()
|
||||
|
||||
assert count == 5
|
||||
176
tests/test_scheduler.py
Normal file
176
tests/test_scheduler.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
调度器单元测试
|
||||
"""
|
||||
import pytest
|
||||
from app.services.scheduler import (
|
||||
AgentScheduler,
|
||||
RoundRobinScheduler,
|
||||
WeightedRoundRobinScheduler,
|
||||
LeastConnectionsScheduler,
|
||||
LeastResponseTimeScheduler,
|
||||
CapabilityMatchScheduler,
|
||||
)
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class MockAgent:
|
||||
"""模拟 Agent"""
|
||||
def __init__(self, id, status='online', weight=10, priority=5,
|
||||
current_sessions=0, connection_limit=5, capabilities=None,
|
||||
last_heartbeat=None):
|
||||
self.id = id
|
||||
self.status = status
|
||||
self.weight = weight
|
||||
self.priority = priority
|
||||
self.current_sessions = current_sessions
|
||||
self.connection_limit = connection_limit
|
||||
self.capabilities = capabilities or []
|
||||
self.last_heartbeat = last_heartbeat
|
||||
|
||||
|
||||
class TestRoundRobinScheduler:
|
||||
"""轮询调度器测试"""
|
||||
|
||||
def test_select_agent(self):
|
||||
"""测试选择 Agent"""
|
||||
scheduler = RoundRobinScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1'),
|
||||
MockAgent(id='agent2'),
|
||||
MockAgent(id='agent3'),
|
||||
]
|
||||
|
||||
# 轮询选择
|
||||
agent1 = scheduler.select_agent(agents)
|
||||
agent2 = scheduler.select_agent(agents)
|
||||
agent3 = scheduler.select_agent(agents)
|
||||
agent4 = scheduler.select_agent(agents) # 循环回第一个
|
||||
|
||||
assert agent1.id == 'agent1'
|
||||
assert agent2.id == 'agent2'
|
||||
assert agent3.id == 'agent3'
|
||||
assert agent4.id == 'agent1'
|
||||
|
||||
def test_no_online_agents(self):
|
||||
"""测试无在线 Agent"""
|
||||
scheduler = RoundRobinScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', status='offline'),
|
||||
MockAgent(id='agent2', status='offline'),
|
||||
]
|
||||
|
||||
agent = scheduler.select_agent(agents)
|
||||
assert agent is None
|
||||
|
||||
|
||||
class TestWeightedRoundRobinScheduler:
|
||||
"""加权轮询调度器测试"""
|
||||
|
||||
def test_weight_distribution(self):
|
||||
"""测试权重分布"""
|
||||
scheduler = WeightedRoundRobinScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', weight=3),
|
||||
MockAgent(id='agent2', weight=1),
|
||||
]
|
||||
|
||||
# 统计选择次数
|
||||
counts = {'agent1': 0, 'agent2': 0}
|
||||
for _ in range(100):
|
||||
agent = scheduler.select_agent(agents)
|
||||
counts[agent.id] += 1
|
||||
|
||||
# agent1 权重 3,agent2 权重 1,比例应该约 3:1
|
||||
assert counts['agent1'] > counts['agent2']
|
||||
|
||||
def test_zero_weight(self):
|
||||
"""测试零权重"""
|
||||
scheduler = WeightedRoundRobinScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', weight=0),
|
||||
MockAgent(id='agent2', weight=0),
|
||||
]
|
||||
|
||||
# 权重都为 0 时返回第一个
|
||||
agent = scheduler.select_agent(agents)
|
||||
assert agent is not None
|
||||
|
||||
|
||||
class TestLeastConnectionsScheduler:
|
||||
"""最少连接调度器测试"""
|
||||
|
||||
def test_select_least_connections(self):
|
||||
"""测试选择最少连接"""
|
||||
scheduler = LeastConnectionsScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', current_sessions=5),
|
||||
MockAgent(id='agent2', current_sessions=2),
|
||||
MockAgent(id='agent3', current_sessions=8),
|
||||
]
|
||||
|
||||
agent = scheduler.select_agent(agents)
|
||||
assert agent.id == 'agent2'
|
||||
|
||||
def test_all_at_limit(self):
|
||||
"""测试所有 Agent 都达上限"""
|
||||
scheduler = LeastConnectionsScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', current_sessions=5, connection_limit=5),
|
||||
MockAgent(id='agent2', current_sessions=5, connection_limit=5),
|
||||
]
|
||||
|
||||
agent = scheduler.select_agent(agents)
|
||||
assert agent is None
|
||||
|
||||
|
||||
class TestCapabilityMatchScheduler:
|
||||
"""能力匹配调度器测试"""
|
||||
|
||||
def test_match_capabilities(self):
|
||||
"""测试能力匹配"""
|
||||
scheduler = CapabilityMatchScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', capabilities=['chat', 'code']),
|
||||
MockAgent(id='agent2', capabilities=['chat']),
|
||||
MockAgent(id='agent3', capabilities=['code', 'translate']),
|
||||
]
|
||||
|
||||
# 需要 code 能力
|
||||
agent = scheduler.select_agent(agents, {'capabilities': ['code']})
|
||||
assert agent.id in ['agent1', 'agent3']
|
||||
|
||||
def test_no_matching_capabilities(self):
|
||||
"""测试无匹配能力"""
|
||||
scheduler = CapabilityMatchScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', capabilities=['chat']),
|
||||
]
|
||||
|
||||
# 需要 code 能力但无 Agent 具备
|
||||
agent = scheduler.select_agent(agents, {'capabilities': ['code']})
|
||||
# 回退到加权轮询
|
||||
assert agent is not None
|
||||
|
||||
|
||||
class TestAgentScheduler:
|
||||
"""Agent 调度器工厂测试"""
|
||||
|
||||
def test_available_strategies(self):
|
||||
"""测试可用策略"""
|
||||
strategies = AgentScheduler.get_available_strategies()
|
||||
|
||||
assert 'round_robin' in strategies
|
||||
assert 'weighted_round_robin' in strategies
|
||||
assert 'least_connections' in strategies
|
||||
assert 'least_response_time' in strategies
|
||||
assert 'capability_match' in strategies
|
||||
|
||||
def test_default_strategy(self):
|
||||
"""测试默认策略"""
|
||||
scheduler = AgentScheduler()
|
||||
assert scheduler.get_strategy() == 'weighted_round_robin'
|
||||
|
||||
def test_custom_strategy(self):
|
||||
"""测试自定义策略"""
|
||||
scheduler = AgentScheduler('round_robin')
|
||||
assert scheduler.get_strategy() == 'round_robin'
|
||||
Reference in New Issue
Block a user