feat: Phase 3 - 工具层 + 测试 + 数据库迁移

This commit is contained in:
2026-03-14 20:14:59 +08:00
parent 1836d118fe
commit 6bafd21e02
14 changed files with 1191 additions and 0 deletions

View File

@@ -11,6 +11,46 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
---
## [0.5.0] - 2026-03-14
### Added
#### 🔧 工具层实现 (Phase 3)
- **validators.py** - 输入验证工具
- Marshmallow Schema 验证
- UUID/Email/用户名/URL 验证
- 字符串清理
- **security.py** - 安全工具
- Token/密码哈希
- API Key 生成
- RateLimiter 限流器
- IPWhitelist IP 白名单
- **helpers.py** - 辅助函数
- 日期时间格式化
- 分页辅助
- JSON 安全解析
#### 🧪 测试模块
- **conftest.py** - pytest 配置
- **test_auth.py** - 认证 API 单元测试
- **test_scheduler.py** - 调度器单元测试
- **test_message_queue.py** - 消息队列单元测试
#### 🗄️ 数据库迁移
- **Alembic 配置** - 数据库迁移工具
- **初始迁移脚本** - 创建所有数据表
### Changed
- **requirements.txt** - 添加 gevent 依赖(用于 gunicorn
---
## [0.4.0] - 2026-03-14
### Added

View File

@@ -1,3 +1,81 @@
"""
工具模块
"""
from .validators import (
UserRegistrationSchema,
UserLoginSchema,
SessionCreateSchema,
MessageSendSchema,
AgentRegistrationSchema,
GatewayRegistrationSchema,
validate_uuid,
validate_email,
validate_username,
validate_url,
sanitize_string,
ValidationUtils,
)
from .security import (
generate_token,
hash_password,
verify_password,
hash_token,
verify_token_hash,
generate_api_key,
secure_compare,
mask_sensitive_data,
RateLimiter,
IPWhitelist,
)
from .helpers import (
format_datetime,
parse_datetime,
format_duration,
truncate_string,
safe_json_loads,
safe_json_dumps,
generate_session_title,
calculate_timeout,
merge_dicts,
filter_none_values,
PaginationHelper,
)
__all__ = [
# Validators
'UserRegistrationSchema',
'UserLoginSchema',
'SessionCreateSchema',
'MessageSendSchema',
'AgentRegistrationSchema',
'GatewayRegistrationSchema',
'validate_uuid',
'validate_email',
'validate_username',
'validate_url',
'sanitize_string',
'ValidationUtils',
# Security
'generate_token',
'hash_password',
'verify_password',
'hash_token',
'verify_token_hash',
'generate_api_key',
'secure_compare',
'mask_sensitive_data',
'RateLimiter',
'IPWhitelist',
# Helpers
'format_datetime',
'parse_datetime',
'format_duration',
'truncate_string',
'safe_json_loads',
'safe_json_dumps',
'generate_session_title',
'calculate_timeout',
'merge_dicts',
'filter_none_values',
'PaginationHelper',
]

113
app/utils/helpers.py Normal file
View File

@@ -0,0 +1,113 @@
"""
辅助函数
"""
from datetime import datetime, timedelta
from typing import Optional, Any
import json
def format_datetime(dt: Optional[datetime]) -> Optional[str]:
"""格式化日期时间"""
if not dt:
return None
return dt.isoformat()
def parse_datetime(dt_str: str) -> Optional[datetime]:
"""解析日期时间字符串"""
if not dt_str:
return None
try:
return datetime.fromisoformat(dt_str)
except:
return None
def format_duration(seconds: int) -> str:
"""格式化时长"""
if seconds < 60:
return f"{seconds}s"
elif seconds < 3600:
return f"{seconds // 60}m"
elif seconds < 86400:
return f"{seconds // 3600}h"
else:
return f"{seconds // 86400}d"
def truncate_string(s: str, max_length: int = 100, suffix: str = "...") -> str:
"""截断字符串"""
if len(s) <= max_length:
return s
return s[:max_length - len(suffix)] + suffix
def safe_json_loads(data: str, default: Any = None) -> Any:
"""安全的 JSON 解析"""
try:
return json.loads(data)
except:
return default
def safe_json_dumps(data: Any, default: str = "{}") -> str:
"""安全的 JSON 序列化"""
try:
return json.dumps(data, ensure_ascii=False)
except:
return default
def generate_session_title() -> str:
"""生成默认会话标题"""
return f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
def calculate_timeout(start_time: datetime, timeout_seconds: int) -> bool:
"""检查是否超时"""
if not start_time:
return True
elapsed = (datetime.utcnow() - start_time).total_seconds()
return elapsed > timeout_seconds
def merge_dicts(base: dict, override: dict) -> dict:
"""合并字典"""
result = base.copy()
result.update(override)
return result
def filter_none_values(data: dict) -> dict:
"""过滤字典中的 None 值"""
return {k: v for k, v in data.items() if v is not None}
class PaginationHelper:
"""分页辅助类"""
@staticmethod
def paginate(query, page: int = 1, per_page: int = 20):
"""分页查询"""
if page < 1:
page = 1
if per_page < 1:
per_page = 20
if per_page > 100:
per_page = 100
total = query.count()
items = query.offset((page - 1) * per_page).limit(per_page).all()
pages = (total + per_page - 1) // per_page
return {
'items': items,
'total': total,
'page': page,
'per_page': per_page,
'pages': pages,
'has_next': page < pages,
'has_prev': page > 1,
}

121
app/utils/security.py Normal file
View File

@@ -0,0 +1,121 @@
"""
安全工具
"""
import hashlib
import hmac
import secrets
import bcrypt
from typing import Optional
def generate_token(length: int = 32) -> str:
"""生成随机 Token"""
return secrets.token_urlsafe(length)
def hash_password(password: str) -> str:
"""密码哈希"""
return bcrypt.hashpw(
password.encode('utf-8'),
bcrypt.gensalt()
).decode('utf-8')
def verify_password(password: str, password_hash: str) -> bool:
"""验证密码"""
return bcrypt.checkpw(
password.encode('utf-8'),
password_hash.encode('utf-8')
)
def hash_token(token: str) -> str:
"""Token 哈希(用于存储)"""
return hashlib.sha256(token.encode()).hexdigest()
def verify_token_hash(token: str, token_hash: str) -> bool:
"""验证 Token 哈希"""
return hash_token(token) == token_hash
def generate_api_key() -> str:
"""生成 API Key"""
return f"pit_{secrets.token_urlsafe(32)}"
def secure_compare(val1: str, val2: str) -> bool:
"""安全字符串比较(防时序攻击)"""
return hmac.compare_digest(val1.encode(), val2.encode())
def mask_sensitive_data(data: str, visible_chars: int = 4) -> str:
"""脱敏处理"""
if len(data) <= visible_chars * 2:
return '*' * len(data)
return data[:visible_chars] + '*' * (len(data) - visible_chars * 2) + data[-visible_chars:]
class RateLimiter:
"""简单内存限流器"""
def __init__(self):
self._storage = {}
def is_allowed(self, key: str, limit: int = 100, window: int = 60) -> bool:
"""检查是否允许请求"""
from time import time
now = time()
window_start = now - window
# 获取该 key 的请求记录
requests = self._storage.get(key, [])
# 清理过期记录
requests = [t for t in requests if t > window_start]
# 检查是否超过限制
if len(requests) >= limit:
self._storage[key] = requests
return False
# 记录新请求
requests.append(now)
self._storage[key] = requests
return True
class IPWhitelist:
"""IP 白名单"""
def __init__(self, allowed_ips: list = None):
self.allowed_ips = set(allowed_ips or [])
self.allow_all = '*' in self.allowed_ips
def is_allowed(self, ip: str) -> bool:
"""检查 IP 是否允许"""
if self.allow_all:
return True
# 支持 CIDR 格式
for allowed in self.allowed_ips:
if '/' in allowed:
if self._ip_in_cidr(ip, allowed):
return True
elif ip == allowed:
return True
return False
def _ip_in_cidr(self, ip: str, cidr: str) -> bool:
"""检查 IP 是否在 CIDR 范围内"""
try:
import ipaddress
network = ipaddress.ip_network(cidr, strict=False)
address = ipaddress.ip_address(ip)
return address in network
except:
return False

120
app/utils/validators.py Normal file
View File

@@ -0,0 +1,120 @@
"""
输入验证工具
"""
from typing import Optional
import re
from marshmallow import Schema, fields, validate, ValidationError
class UserRegistrationSchema(Schema):
"""用户注册验证"""
username = fields.Str(required=True, validate=validate.Length(min=3, max=80))
email = fields.Email(required=True)
password = fields.Str(required=True, validate=validate.Length(min=6, max=128))
nickname = fields.Str(validate=validate.Length(max=80), load_default=None)
class UserLoginSchema(Schema):
"""用户登录验证"""
username = fields.Str(required=True)
password = fields.Str(required=True)
class SessionCreateSchema(Schema):
"""会话创建验证"""
title = fields.Str(validate=validate.Length(max=200), load_default=None)
agent_id = fields.Str(validate=validate.Length(max=36), load_default=None)
priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
class MessageSendSchema(Schema):
"""消息发送验证"""
session_id = fields.Str(required=True, validate=validate.Length(max=36))
content = fields.Str(required=True, validate=validate.Length(min=1, max=10000))
message_type = fields.Str(validate=validate.OneOf(['text', 'media', 'system']), load_default='text')
reply_to = fields.Str(validate=validate.Length(max=36), load_default=None)
class AgentRegistrationSchema(Schema):
"""Agent 注册验证"""
name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
gateway_id = fields.Str(validate=validate.Length(max=36), load_default=None)
model = fields.Str(validate=validate.Length(max=80), load_default=None)
capabilities = fields.List(fields.Str(), load_default=[])
priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
weight = fields.Int(validate=validate.Range(min=1, max=100), load_default=10)
class GatewayRegistrationSchema(Schema):
"""Gateway 注册验证"""
name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
url = fields.Str(required=True, validate=validate.Length(max=256))
token = fields.Str(validate=validate.Length(max=256), load_default=None)
def validate_uuid(uuid_str: str) -> bool:
"""验证 UUID 格式"""
pattern = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$'
return bool(re.match(pattern, uuid_str.lower()))
def validate_email(email: str) -> bool:
"""验证邮箱格式"""
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
return bool(re.match(pattern, email))
def validate_username(username: str) -> bool:
"""验证用户名格式"""
# 字母、数字、下划线3-80 字符
pattern = r'^[a-zA-Z0-9_]{3,80}$'
return bool(re.match(pattern, username))
def validate_url(url: str) -> bool:
"""验证 URL 格式"""
pattern = r'^https?://[^\s/$.?#].[^\s]*$'
return bool(re.match(pattern, url, re.IGNORECASE))
def sanitize_string(input_str: str, max_length: int = 1000) -> str:
"""清理字符串输入"""
if not input_str:
return ""
# 去除首尾空格
cleaned = input_str.strip()
# 限制长度
if len(cleaned) > max_length:
cleaned = cleaned[:max_length]
# 去除危险字符
cleaned = cleaned.replace('\x00', '')
return cleaned
class ValidationUtils:
"""验证工具类"""
@staticmethod
def validate_json(data: dict, schema: Schema) -> tuple:
"""验证 JSON 数据"""
try:
result = schema.load(data)
return True, result, None
except ValidationError as e:
return False, None, e.messages
@staticmethod
def validate_pagination(page: int, per_page: int, max_per_page: int = 100) -> tuple:
"""验证分页参数"""
if page < 1:
page = 1
if per_page < 1:
per_page = 20
if per_page > max_per_page:
per_page = max_per_page
return page, per_page

48
migrations/alembic.ini Normal file
View File

@@ -0,0 +1,48 @@
[alembic]
# 脚本位置
script_location = migrations
# 模板文件
template_file =
# 最大保留版本数
max_num = 10
# 数据库 URL运行时覆盖
sqlalchemy.url = sqlite:///pit_router.db
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

68
migrations/env.py Normal file
View File

@@ -0,0 +1,68 @@
"""
Alembic 环境配置
"""
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
import os
import sys
# 添加项目路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.extensions import db
from app.config import Config
# Alembic Config 对象
config = context.config
# 配置日志
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# 模型元数据
target_metadata = db.Model.metadata
def get_url():
"""获取数据库 URL"""
return os.getenv('DATABASE_URL', 'sqlite:///pit_router.db')
def run_migrations_offline() -> None:
"""离线迁移"""
url = get_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""在线迁移"""
configuration = config.get_section(config.config_ini_section)
configuration["sqlalchemy.url"] = get_url()
connectable = engine_from_config(
configuration,
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,134 @@
"""
初始化迁移脚本
Revision ID: initial
Revises:
Create Date: 2026-03-14
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers
revision = 'initial'
down_revision = None
branch_labels = None
depends_on = None
def upgrade():
"""创建表"""
# 用户表
op.create_table(
'users',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('username', sa.String(80), unique=True, nullable=False),
sa.Column('password_hash', sa.String(256), nullable=False),
sa.Column('email', sa.String(120), unique=True, nullable=False),
sa.Column('nickname', sa.String(80)),
sa.Column('role', sa.String(20), default='user'),
sa.Column('status', sa.String(20), default='active'),
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
sa.Column('last_login_at', sa.DateTime),
)
# Gateway 表
op.create_table(
'gateways',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('name', sa.String(80), unique=True, nullable=False),
sa.Column('url', sa.String(256), nullable=False),
sa.Column('token_hash', sa.String(256)),
sa.Column('status', sa.String(20), default='offline'),
sa.Column('agent_count', sa.Integer, default=0),
sa.Column('connection_limit', sa.Integer, default=10),
sa.Column('heartbeat_interval', sa.Integer, default=60),
sa.Column('allowed_ips', sa.JSON),
sa.Column('last_heartbeat', sa.DateTime),
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
)
# Agent 表
op.create_table(
'agents',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('name', sa.String(80), nullable=False),
sa.Column('display_name', sa.String(80)),
sa.Column('gateway_id', sa.String(36), sa.ForeignKey('gateways.id')),
sa.Column('socket_id', sa.String(100)),
sa.Column('model', sa.String(80)),
sa.Column('capabilities', sa.JSON),
sa.Column('status', sa.String(20), default='offline'),
sa.Column('priority', sa.Integer, default=5),
sa.Column('weight', sa.Integer, default=10),
sa.Column('connection_limit', sa.Integer, default=5),
sa.Column('current_sessions', sa.Integer, default=0),
sa.Column('last_heartbeat', sa.DateTime),
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
)
# 会话表
op.create_table(
'sessions',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
sa.Column('primary_agent_id', sa.String(36), sa.ForeignKey('agents.id')),
sa.Column('participating_agent_ids', sa.JSON),
sa.Column('user_socket_id', sa.String(100)),
sa.Column('title', sa.String(200)),
sa.Column('channel_type', sa.String(20), default='web'),
sa.Column('status', sa.String(20), default='active'),
sa.Column('message_count', sa.Integer, default=0),
sa.Column('unread_count', sa.Integer, default=0),
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
sa.Column('updated_at', sa.DateTime),
sa.Column('last_active_at', sa.DateTime),
)
# 消息表
op.create_table(
'messages',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('session_id', sa.String(36), sa.ForeignKey('sessions.id'), nullable=False),
sa.Column('sender_type', sa.String(20), nullable=False),
sa.Column('sender_id', sa.String(36), nullable=False),
sa.Column('message_type', sa.String(20), default='text'),
sa.Column('content', sa.Text),
sa.Column('content_type', sa.String(20), default='markdown'),
sa.Column('reply_to', sa.String(36)),
sa.Column('status', sa.String(20), default='sent'),
sa.Column('ack_status', sa.String(20), default='pending'),
sa.Column('retry_count', sa.Integer, default=0),
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
sa.Column('delivered_at', sa.DateTime),
)
# 连接表
op.create_table(
'connections',
sa.Column('id', sa.String(36), primary_key=True),
sa.Column('socket_id', sa.String(100), unique=True, nullable=False),
sa.Column('connection_type', sa.String(20), nullable=False),
sa.Column('entity_id', sa.String(36), nullable=False),
sa.Column('entity_type', sa.String(20), nullable=False),
sa.Column('ip_address', sa.String(45)),
sa.Column('user_agent', sa.String(500)),
sa.Column('status', sa.String(20), default='connected'),
sa.Column('auth_token', sa.String(500)),
sa.Column('connected_at', sa.DateTime, default=sa.func.now()),
sa.Column('last_activity', sa.DateTime),
sa.Column('disconnected_at', sa.DateTime),
)
# 创建索引
op.create_index('ix_messages_session_id', 'messages', ['session_id'])
op.create_index('ix_sessions_user_id', 'sessions', ['user_id'])
op.create_index('ix_sessions_status', 'sessions', ['status'])
def downgrade():
"""删除表"""
op.drop_table('connections')
op.drop_table('messages')
op.drop_table('sessions')
op.drop_table('agents')
op.drop_table('gateways')
op.drop_table('users')

View File

@@ -7,6 +7,7 @@ Flask-Login==0.6.3
Flask-Migrate==4.0.5
Flask-SocketIO==5.3.6
Flask-SQLAlchemy==3.1.1
gevent==24.2.1
gunicorn==21.2.0
psycopg2-binary==2.9.9
PyJWT==2.8.0

3
tests/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
测试模块初始化
"""

35
tests/conftest.py Normal file
View File

@@ -0,0 +1,35 @@
"""
pytest 配置
"""
import pytest
import sys
import os
# 添加项目根目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@pytest.fixture
def app():
"""创建测试应用"""
from app import create_app
app = create_app('testing')
with app.app_context():
from app.extensions import db
db.create_all()
yield app
db.session.remove()
db.drop_all()
@pytest.fixture
def client(app):
"""创建测试客户端"""
return app.test_client()
@pytest.fixture
def runner(app):
"""创建测试运行器"""
return app.test_cli_runner()

158
tests/test_auth.py Normal file
View File

@@ -0,0 +1,158 @@
"""
认证 API 单元测试
"""
import pytest
from app import create_app
from app.extensions import db
from app.models import User
from app.utils.security import hash_password
@pytest.fixture
def app():
"""创建测试应用"""
app = create_app('testing')
with app.app_context():
db.create_all()
yield app
db.drop_all()
@pytest.fixture
def client(app):
"""创建测试客户端"""
return app.test_client()
class TestAuthAPI:
"""认证 API 测试"""
def test_register_success(self, client):
"""测试用户注册成功"""
response = client.post('/api/auth/register', json={
'username': 'testuser',
'email': 'test@example.com',
'password': 'password123',
'nickname': 'Test User'
})
assert response.status_code == 201
data = response.get_json()
assert data['user']['username'] == 'testuser'
assert data['user']['email'] == 'test@example.com'
def test_register_duplicate_username(self, client, app):
"""测试重复用户名"""
# 创建已存在的用户
with app.app_context():
user = User(
username='testuser',
email='existing@example.com',
password_hash=hash_password('password')
)
db.session.add(user)
db.session.commit()
# 尝试注册相同用户名
response = client.post('/api/auth/register', json={
'username': 'testuser',
'email': 'new@example.com',
'password': 'password123'
})
assert response.status_code == 400
def test_register_missing_fields(self, client):
"""测试缺少必填字段"""
response = client.post('/api/auth/register', json={
'username': 'testuser'
})
assert response.status_code == 400
def test_login_success(self, client, app):
"""测试登录成功"""
# 创建用户
with app.app_context():
user = User(
username='testuser',
email='test@example.com',
password_hash=hash_password('password123')
)
db.session.add(user)
db.session.commit()
# 登录
response = client.post('/api/auth/login', json={
'username': 'testuser',
'password': 'password123'
})
assert response.status_code == 200
data = response.get_json()
assert 'access_token' in data
assert 'refresh_token' in data
def test_login_invalid_credentials(self, client):
"""测试无效凭据"""
response = client.post('/api/auth/login', json={
'username': 'nonexistent',
'password': 'wrong'
})
assert response.status_code == 401
def test_verify_token(self, client, app):
"""测试 Token 验证"""
# 创建用户
with app.app_context():
user = User(
username='testuser',
email='test@example.com',
password_hash=hash_password('password123')
)
db.session.add(user)
db.session.commit()
# 登录获取 token
login_response = client.post('/api/auth/login', json={
'username': 'testuser',
'password': 'password123'
})
token = login_response.get_json()['access_token']
# 验证 token
response = client.post('/api/auth/verify', headers={
'Authorization': f'Bearer {token}'
})
assert response.status_code == 200
data = response.get_json()
assert data['valid'] is True
class TestValidation:
"""验证测试"""
def test_validate_email(self, client):
"""测试邮箱验证"""
from app.utils.validators import validate_email
assert validate_email('test@example.com') is True
assert validate_email('invalid-email') is False
def test_validate_username(self, client):
"""测试用户名验证"""
from app.utils.validators import validate_username
assert validate_username('testuser') is True
assert validate_username('ab') is False # 太短
assert validate_username('test-user') is True
assert validate_username('test@user') is False # 包含非法字符
def test_validate_uuid(self, client):
"""测试 UUID 验证"""
from app.utils.validators import validate_uuid
assert validate_uuid('550e8400-e29b-41d4-a716-446655440000') is True
assert validate_uuid('invalid-uuid') is False

View File

@@ -0,0 +1,96 @@
"""
消息队列单元测试
"""
import pytest
from app.services.message_queue import MessageQueue
from unittest.mock import MagicMock, patch
class TestMessageQueue:
"""消息队列测试"""
@pytest.fixture
def mock_redis(self):
"""模拟 Redis 客户端"""
with patch('app.services.message_queue.redis_client') as mock:
mock.hset.return_value = 1
mock.rpush.return_value = 1
mock.lpop.return_value = None
mock.hgetall.return_value = {}
mock.delete.return_value = 1
mock.llen.return_value = 0
yield mock
def test_enqueue(self, mock_redis):
"""测试消息入队"""
message = MagicMock()
message.id = 'msg-123'
message.session_id = 'session-123'
message.sender_type = 'user'
message.sender_id = 'user-123'
message.content = 'Hello'
message.status = 'sent'
message.retry_count = 0
message.created_at = None
result = MessageQueue.enqueue(message)
assert result is True
mock_redis.hset.assert_called()
mock_redis.rpush.assert_called()
def test_dequeue_empty(self, mock_redis):
"""测试空队列出队"""
mock_redis.lpop.return_value = None
result = MessageQueue.dequeue()
assert result is None
def test_dequeue_success(self, mock_redis):
"""测试成功出队"""
mock_redis.lpop.return_value = 'msg-123'
mock_redis.hgetall.return_value = {
'id': 'msg-123',
'session_id': 'session-123',
'content': 'Hello',
'retry_count': '0',
}
result = MessageQueue.dequeue()
assert result is not None
assert result['id'] == 'msg-123'
def test_ack(self, mock_redis):
"""测试消息确认"""
result = MessageQueue.ack('msg-123')
assert result is True
mock_redis.delete.assert_called()
def test_retry_within_limit(self, mock_redis):
"""测试重试(未超限)"""
mock_redis.hget.return_value = '1'
result = MessageQueue.retry('msg-123')
assert result is True
mock_redis.hset.assert_called()
mock_redis.rpush.assert_called()
def test_retry_exceed_limit(self, mock_redis):
"""测试重试(超限)"""
mock_redis.hget.return_value = '3'
result = MessageQueue.retry('msg-123')
assert result is False
def test_get_pending_count(self, mock_redis):
"""测试获取待处理数量"""
mock_redis.llen.return_value = 5
count = MessageQueue.get_pending_count()
assert count == 5

176
tests/test_scheduler.py Normal file
View File

@@ -0,0 +1,176 @@
"""
调度器单元测试
"""
import pytest
from app.services.scheduler import (
AgentScheduler,
RoundRobinScheduler,
WeightedRoundRobinScheduler,
LeastConnectionsScheduler,
LeastResponseTimeScheduler,
CapabilityMatchScheduler,
)
from datetime import datetime, timedelta
class MockAgent:
"""模拟 Agent"""
def __init__(self, id, status='online', weight=10, priority=5,
current_sessions=0, connection_limit=5, capabilities=None,
last_heartbeat=None):
self.id = id
self.status = status
self.weight = weight
self.priority = priority
self.current_sessions = current_sessions
self.connection_limit = connection_limit
self.capabilities = capabilities or []
self.last_heartbeat = last_heartbeat
class TestRoundRobinScheduler:
"""轮询调度器测试"""
def test_select_agent(self):
"""测试选择 Agent"""
scheduler = RoundRobinScheduler()
agents = [
MockAgent(id='agent1'),
MockAgent(id='agent2'),
MockAgent(id='agent3'),
]
# 轮询选择
agent1 = scheduler.select_agent(agents)
agent2 = scheduler.select_agent(agents)
agent3 = scheduler.select_agent(agents)
agent4 = scheduler.select_agent(agents) # 循环回第一个
assert agent1.id == 'agent1'
assert agent2.id == 'agent2'
assert agent3.id == 'agent3'
assert agent4.id == 'agent1'
def test_no_online_agents(self):
"""测试无在线 Agent"""
scheduler = RoundRobinScheduler()
agents = [
MockAgent(id='agent1', status='offline'),
MockAgent(id='agent2', status='offline'),
]
agent = scheduler.select_agent(agents)
assert agent is None
class TestWeightedRoundRobinScheduler:
"""加权轮询调度器测试"""
def test_weight_distribution(self):
"""测试权重分布"""
scheduler = WeightedRoundRobinScheduler()
agents = [
MockAgent(id='agent1', weight=3),
MockAgent(id='agent2', weight=1),
]
# 统计选择次数
counts = {'agent1': 0, 'agent2': 0}
for _ in range(100):
agent = scheduler.select_agent(agents)
counts[agent.id] += 1
# agent1 权重 3agent2 权重 1比例应该约 3:1
assert counts['agent1'] > counts['agent2']
def test_zero_weight(self):
"""测试零权重"""
scheduler = WeightedRoundRobinScheduler()
agents = [
MockAgent(id='agent1', weight=0),
MockAgent(id='agent2', weight=0),
]
# 权重都为 0 时返回第一个
agent = scheduler.select_agent(agents)
assert agent is not None
class TestLeastConnectionsScheduler:
"""最少连接调度器测试"""
def test_select_least_connections(self):
"""测试选择最少连接"""
scheduler = LeastConnectionsScheduler()
agents = [
MockAgent(id='agent1', current_sessions=5),
MockAgent(id='agent2', current_sessions=2),
MockAgent(id='agent3', current_sessions=8),
]
agent = scheduler.select_agent(agents)
assert agent.id == 'agent2'
def test_all_at_limit(self):
"""测试所有 Agent 都达上限"""
scheduler = LeastConnectionsScheduler()
agents = [
MockAgent(id='agent1', current_sessions=5, connection_limit=5),
MockAgent(id='agent2', current_sessions=5, connection_limit=5),
]
agent = scheduler.select_agent(agents)
assert agent is None
class TestCapabilityMatchScheduler:
"""能力匹配调度器测试"""
def test_match_capabilities(self):
"""测试能力匹配"""
scheduler = CapabilityMatchScheduler()
agents = [
MockAgent(id='agent1', capabilities=['chat', 'code']),
MockAgent(id='agent2', capabilities=['chat']),
MockAgent(id='agent3', capabilities=['code', 'translate']),
]
# 需要 code 能力
agent = scheduler.select_agent(agents, {'capabilities': ['code']})
assert agent.id in ['agent1', 'agent3']
def test_no_matching_capabilities(self):
"""测试无匹配能力"""
scheduler = CapabilityMatchScheduler()
agents = [
MockAgent(id='agent1', capabilities=['chat']),
]
# 需要 code 能力但无 Agent 具备
agent = scheduler.select_agent(agents, {'capabilities': ['code']})
# 回退到加权轮询
assert agent is not None
class TestAgentScheduler:
"""Agent 调度器工厂测试"""
def test_available_strategies(self):
"""测试可用策略"""
strategies = AgentScheduler.get_available_strategies()
assert 'round_robin' in strategies
assert 'weighted_round_robin' in strategies
assert 'least_connections' in strategies
assert 'least_response_time' in strategies
assert 'capability_match' in strategies
def test_default_strategy(self):
"""测试默认策略"""
scheduler = AgentScheduler()
assert scheduler.get_strategy() == 'weighted_round_robin'
def test_custom_strategy(self):
"""测试自定义策略"""
scheduler = AgentScheduler('round_robin')
assert scheduler.get_strategy() == 'round_robin'