feat: Phase 3 - 工具层 + 测试 + 数据库迁移

This commit is contained in:
2026-03-14 20:14:59 +08:00
parent 1836d118fe
commit 6bafd21e02
14 changed files with 1191 additions and 0 deletions

3
tests/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
"""
测试模块初始化
"""

35
tests/conftest.py Normal file
View File

@@ -0,0 +1,35 @@
"""
pytest 配置
"""
import pytest
import sys
import os
# 添加项目根目录到路径
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@pytest.fixture
def app():
"""创建测试应用"""
from app import create_app
app = create_app('testing')
with app.app_context():
from app.extensions import db
db.create_all()
yield app
db.session.remove()
db.drop_all()
@pytest.fixture
def client(app):
"""创建测试客户端"""
return app.test_client()
@pytest.fixture
def runner(app):
"""创建测试运行器"""
return app.test_cli_runner()

158
tests/test_auth.py Normal file
View File

@@ -0,0 +1,158 @@
"""
认证 API 单元测试
"""
import pytest
from app import create_app
from app.extensions import db
from app.models import User
from app.utils.security import hash_password
@pytest.fixture
def app():
"""创建测试应用"""
app = create_app('testing')
with app.app_context():
db.create_all()
yield app
db.drop_all()
@pytest.fixture
def client(app):
"""创建测试客户端"""
return app.test_client()
class TestAuthAPI:
"""认证 API 测试"""
def test_register_success(self, client):
"""测试用户注册成功"""
response = client.post('/api/auth/register', json={
'username': 'testuser',
'email': 'test@example.com',
'password': 'password123',
'nickname': 'Test User'
})
assert response.status_code == 201
data = response.get_json()
assert data['user']['username'] == 'testuser'
assert data['user']['email'] == 'test@example.com'
def test_register_duplicate_username(self, client, app):
"""测试重复用户名"""
# 创建已存在的用户
with app.app_context():
user = User(
username='testuser',
email='existing@example.com',
password_hash=hash_password('password')
)
db.session.add(user)
db.session.commit()
# 尝试注册相同用户名
response = client.post('/api/auth/register', json={
'username': 'testuser',
'email': 'new@example.com',
'password': 'password123'
})
assert response.status_code == 400
def test_register_missing_fields(self, client):
"""测试缺少必填字段"""
response = client.post('/api/auth/register', json={
'username': 'testuser'
})
assert response.status_code == 400
def test_login_success(self, client, app):
"""测试登录成功"""
# 创建用户
with app.app_context():
user = User(
username='testuser',
email='test@example.com',
password_hash=hash_password('password123')
)
db.session.add(user)
db.session.commit()
# 登录
response = client.post('/api/auth/login', json={
'username': 'testuser',
'password': 'password123'
})
assert response.status_code == 200
data = response.get_json()
assert 'access_token' in data
assert 'refresh_token' in data
def test_login_invalid_credentials(self, client):
"""测试无效凭据"""
response = client.post('/api/auth/login', json={
'username': 'nonexistent',
'password': 'wrong'
})
assert response.status_code == 401
def test_verify_token(self, client, app):
"""测试 Token 验证"""
# 创建用户
with app.app_context():
user = User(
username='testuser',
email='test@example.com',
password_hash=hash_password('password123')
)
db.session.add(user)
db.session.commit()
# 登录获取 token
login_response = client.post('/api/auth/login', json={
'username': 'testuser',
'password': 'password123'
})
token = login_response.get_json()['access_token']
# 验证 token
response = client.post('/api/auth/verify', headers={
'Authorization': f'Bearer {token}'
})
assert response.status_code == 200
data = response.get_json()
assert data['valid'] is True
class TestValidation:
"""验证测试"""
def test_validate_email(self, client):
"""测试邮箱验证"""
from app.utils.validators import validate_email
assert validate_email('test@example.com') is True
assert validate_email('invalid-email') is False
def test_validate_username(self, client):
"""测试用户名验证"""
from app.utils.validators import validate_username
assert validate_username('testuser') is True
assert validate_username('ab') is False # 太短
assert validate_username('test-user') is True
assert validate_username('test@user') is False # 包含非法字符
def test_validate_uuid(self, client):
"""测试 UUID 验证"""
from app.utils.validators import validate_uuid
assert validate_uuid('550e8400-e29b-41d4-a716-446655440000') is True
assert validate_uuid('invalid-uuid') is False

View File

@@ -0,0 +1,96 @@
"""
消息队列单元测试
"""
import pytest
from app.services.message_queue import MessageQueue
from unittest.mock import MagicMock, patch
class TestMessageQueue:
"""消息队列测试"""
@pytest.fixture
def mock_redis(self):
"""模拟 Redis 客户端"""
with patch('app.services.message_queue.redis_client') as mock:
mock.hset.return_value = 1
mock.rpush.return_value = 1
mock.lpop.return_value = None
mock.hgetall.return_value = {}
mock.delete.return_value = 1
mock.llen.return_value = 0
yield mock
def test_enqueue(self, mock_redis):
"""测试消息入队"""
message = MagicMock()
message.id = 'msg-123'
message.session_id = 'session-123'
message.sender_type = 'user'
message.sender_id = 'user-123'
message.content = 'Hello'
message.status = 'sent'
message.retry_count = 0
message.created_at = None
result = MessageQueue.enqueue(message)
assert result is True
mock_redis.hset.assert_called()
mock_redis.rpush.assert_called()
def test_dequeue_empty(self, mock_redis):
"""测试空队列出队"""
mock_redis.lpop.return_value = None
result = MessageQueue.dequeue()
assert result is None
def test_dequeue_success(self, mock_redis):
"""测试成功出队"""
mock_redis.lpop.return_value = 'msg-123'
mock_redis.hgetall.return_value = {
'id': 'msg-123',
'session_id': 'session-123',
'content': 'Hello',
'retry_count': '0',
}
result = MessageQueue.dequeue()
assert result is not None
assert result['id'] == 'msg-123'
def test_ack(self, mock_redis):
"""测试消息确认"""
result = MessageQueue.ack('msg-123')
assert result is True
mock_redis.delete.assert_called()
def test_retry_within_limit(self, mock_redis):
"""测试重试(未超限)"""
mock_redis.hget.return_value = '1'
result = MessageQueue.retry('msg-123')
assert result is True
mock_redis.hset.assert_called()
mock_redis.rpush.assert_called()
def test_retry_exceed_limit(self, mock_redis):
"""测试重试(超限)"""
mock_redis.hget.return_value = '3'
result = MessageQueue.retry('msg-123')
assert result is False
def test_get_pending_count(self, mock_redis):
"""测试获取待处理数量"""
mock_redis.llen.return_value = 5
count = MessageQueue.get_pending_count()
assert count == 5

176
tests/test_scheduler.py Normal file
View File

@@ -0,0 +1,176 @@
"""
调度器单元测试
"""
import pytest
from app.services.scheduler import (
AgentScheduler,
RoundRobinScheduler,
WeightedRoundRobinScheduler,
LeastConnectionsScheduler,
LeastResponseTimeScheduler,
CapabilityMatchScheduler,
)
from datetime import datetime, timedelta
class MockAgent:
"""模拟 Agent"""
def __init__(self, id, status='online', weight=10, priority=5,
current_sessions=0, connection_limit=5, capabilities=None,
last_heartbeat=None):
self.id = id
self.status = status
self.weight = weight
self.priority = priority
self.current_sessions = current_sessions
self.connection_limit = connection_limit
self.capabilities = capabilities or []
self.last_heartbeat = last_heartbeat
class TestRoundRobinScheduler:
"""轮询调度器测试"""
def test_select_agent(self):
"""测试选择 Agent"""
scheduler = RoundRobinScheduler()
agents = [
MockAgent(id='agent1'),
MockAgent(id='agent2'),
MockAgent(id='agent3'),
]
# 轮询选择
agent1 = scheduler.select_agent(agents)
agent2 = scheduler.select_agent(agents)
agent3 = scheduler.select_agent(agents)
agent4 = scheduler.select_agent(agents) # 循环回第一个
assert agent1.id == 'agent1'
assert agent2.id == 'agent2'
assert agent3.id == 'agent3'
assert agent4.id == 'agent1'
def test_no_online_agents(self):
"""测试无在线 Agent"""
scheduler = RoundRobinScheduler()
agents = [
MockAgent(id='agent1', status='offline'),
MockAgent(id='agent2', status='offline'),
]
agent = scheduler.select_agent(agents)
assert agent is None
class TestWeightedRoundRobinScheduler:
"""加权轮询调度器测试"""
def test_weight_distribution(self):
"""测试权重分布"""
scheduler = WeightedRoundRobinScheduler()
agents = [
MockAgent(id='agent1', weight=3),
MockAgent(id='agent2', weight=1),
]
# 统计选择次数
counts = {'agent1': 0, 'agent2': 0}
for _ in range(100):
agent = scheduler.select_agent(agents)
counts[agent.id] += 1
# agent1 权重 3agent2 权重 1比例应该约 3:1
assert counts['agent1'] > counts['agent2']
def test_zero_weight(self):
"""测试零权重"""
scheduler = WeightedRoundRobinScheduler()
agents = [
MockAgent(id='agent1', weight=0),
MockAgent(id='agent2', weight=0),
]
# 权重都为 0 时返回第一个
agent = scheduler.select_agent(agents)
assert agent is not None
class TestLeastConnectionsScheduler:
"""最少连接调度器测试"""
def test_select_least_connections(self):
"""测试选择最少连接"""
scheduler = LeastConnectionsScheduler()
agents = [
MockAgent(id='agent1', current_sessions=5),
MockAgent(id='agent2', current_sessions=2),
MockAgent(id='agent3', current_sessions=8),
]
agent = scheduler.select_agent(agents)
assert agent.id == 'agent2'
def test_all_at_limit(self):
"""测试所有 Agent 都达上限"""
scheduler = LeastConnectionsScheduler()
agents = [
MockAgent(id='agent1', current_sessions=5, connection_limit=5),
MockAgent(id='agent2', current_sessions=5, connection_limit=5),
]
agent = scheduler.select_agent(agents)
assert agent is None
class TestCapabilityMatchScheduler:
"""能力匹配调度器测试"""
def test_match_capabilities(self):
"""测试能力匹配"""
scheduler = CapabilityMatchScheduler()
agents = [
MockAgent(id='agent1', capabilities=['chat', 'code']),
MockAgent(id='agent2', capabilities=['chat']),
MockAgent(id='agent3', capabilities=['code', 'translate']),
]
# 需要 code 能力
agent = scheduler.select_agent(agents, {'capabilities': ['code']})
assert agent.id in ['agent1', 'agent3']
def test_no_matching_capabilities(self):
"""测试无匹配能力"""
scheduler = CapabilityMatchScheduler()
agents = [
MockAgent(id='agent1', capabilities=['chat']),
]
# 需要 code 能力但无 Agent 具备
agent = scheduler.select_agent(agents, {'capabilities': ['code']})
# 回退到加权轮询
assert agent is not None
class TestAgentScheduler:
"""Agent 调度器工厂测试"""
def test_available_strategies(self):
"""测试可用策略"""
strategies = AgentScheduler.get_available_strategies()
assert 'round_robin' in strategies
assert 'weighted_round_robin' in strategies
assert 'least_connections' in strategies
assert 'least_response_time' in strategies
assert 'capability_match' in strategies
def test_default_strategy(self):
"""测试默认策略"""
scheduler = AgentScheduler()
assert scheduler.get_strategy() == 'weighted_round_robin'
def test_custom_strategy(self):
"""测试自定义策略"""
scheduler = AgentScheduler('round_robin')
assert scheduler.get_strategy() == 'round_robin'