feat: Phase 3 - 工具层 + 测试 + 数据库迁移
This commit is contained in:
3
tests/__init__.py
Normal file
3
tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
测试模块初始化
|
||||
"""
|
||||
35
tests/conftest.py
Normal file
35
tests/conftest.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""
|
||||
pytest 配置
|
||||
"""
|
||||
import pytest
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到路径
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""创建测试应用"""
|
||||
from app import create_app
|
||||
app = create_app('testing')
|
||||
|
||||
with app.app_context():
|
||||
from app.extensions import db
|
||||
db.create_all()
|
||||
yield app
|
||||
db.session.remove()
|
||||
db.drop_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""创建测试客户端"""
|
||||
return app.test_client()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(app):
|
||||
"""创建测试运行器"""
|
||||
return app.test_cli_runner()
|
||||
158
tests/test_auth.py
Normal file
158
tests/test_auth.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
认证 API 单元测试
|
||||
"""
|
||||
import pytest
|
||||
from app import create_app
|
||||
from app.extensions import db
|
||||
from app.models import User
|
||||
from app.utils.security import hash_password
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""创建测试应用"""
|
||||
app = create_app('testing')
|
||||
with app.app_context():
|
||||
db.create_all()
|
||||
yield app
|
||||
db.drop_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(app):
|
||||
"""创建测试客户端"""
|
||||
return app.test_client()
|
||||
|
||||
|
||||
class TestAuthAPI:
|
||||
"""认证 API 测试"""
|
||||
|
||||
def test_register_success(self, client):
|
||||
"""测试用户注册成功"""
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'testuser',
|
||||
'email': 'test@example.com',
|
||||
'password': 'password123',
|
||||
'nickname': 'Test User'
|
||||
})
|
||||
|
||||
assert response.status_code == 201
|
||||
data = response.get_json()
|
||||
assert data['user']['username'] == 'testuser'
|
||||
assert data['user']['email'] == 'test@example.com'
|
||||
|
||||
def test_register_duplicate_username(self, client, app):
|
||||
"""测试重复用户名"""
|
||||
# 创建已存在的用户
|
||||
with app.app_context():
|
||||
user = User(
|
||||
username='testuser',
|
||||
email='existing@example.com',
|
||||
password_hash=hash_password('password')
|
||||
)
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
# 尝试注册相同用户名
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'testuser',
|
||||
'email': 'new@example.com',
|
||||
'password': 'password123'
|
||||
})
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_register_missing_fields(self, client):
|
||||
"""测试缺少必填字段"""
|
||||
response = client.post('/api/auth/register', json={
|
||||
'username': 'testuser'
|
||||
})
|
||||
|
||||
assert response.status_code == 400
|
||||
|
||||
def test_login_success(self, client, app):
|
||||
"""测试登录成功"""
|
||||
# 创建用户
|
||||
with app.app_context():
|
||||
user = User(
|
||||
username='testuser',
|
||||
email='test@example.com',
|
||||
password_hash=hash_password('password123')
|
||||
)
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
# 登录
|
||||
response = client.post('/api/auth/login', json={
|
||||
'username': 'testuser',
|
||||
'password': 'password123'
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert 'access_token' in data
|
||||
assert 'refresh_token' in data
|
||||
|
||||
def test_login_invalid_credentials(self, client):
|
||||
"""测试无效凭据"""
|
||||
response = client.post('/api/auth/login', json={
|
||||
'username': 'nonexistent',
|
||||
'password': 'wrong'
|
||||
})
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_verify_token(self, client, app):
|
||||
"""测试 Token 验证"""
|
||||
# 创建用户
|
||||
with app.app_context():
|
||||
user = User(
|
||||
username='testuser',
|
||||
email='test@example.com',
|
||||
password_hash=hash_password('password123')
|
||||
)
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
# 登录获取 token
|
||||
login_response = client.post('/api/auth/login', json={
|
||||
'username': 'testuser',
|
||||
'password': 'password123'
|
||||
})
|
||||
token = login_response.get_json()['access_token']
|
||||
|
||||
# 验证 token
|
||||
response = client.post('/api/auth/verify', headers={
|
||||
'Authorization': f'Bearer {token}'
|
||||
})
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert data['valid'] is True
|
||||
|
||||
|
||||
class TestValidation:
|
||||
"""验证测试"""
|
||||
|
||||
def test_validate_email(self, client):
|
||||
"""测试邮箱验证"""
|
||||
from app.utils.validators import validate_email
|
||||
|
||||
assert validate_email('test@example.com') is True
|
||||
assert validate_email('invalid-email') is False
|
||||
|
||||
def test_validate_username(self, client):
|
||||
"""测试用户名验证"""
|
||||
from app.utils.validators import validate_username
|
||||
|
||||
assert validate_username('testuser') is True
|
||||
assert validate_username('ab') is False # 太短
|
||||
assert validate_username('test-user') is True
|
||||
assert validate_username('test@user') is False # 包含非法字符
|
||||
|
||||
def test_validate_uuid(self, client):
|
||||
"""测试 UUID 验证"""
|
||||
from app.utils.validators import validate_uuid
|
||||
|
||||
assert validate_uuid('550e8400-e29b-41d4-a716-446655440000') is True
|
||||
assert validate_uuid('invalid-uuid') is False
|
||||
96
tests/test_message_queue.py
Normal file
96
tests/test_message_queue.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
消息队列单元测试
|
||||
"""
|
||||
import pytest
|
||||
from app.services.message_queue import MessageQueue
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
class TestMessageQueue:
|
||||
"""消息队列测试"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis(self):
|
||||
"""模拟 Redis 客户端"""
|
||||
with patch('app.services.message_queue.redis_client') as mock:
|
||||
mock.hset.return_value = 1
|
||||
mock.rpush.return_value = 1
|
||||
mock.lpop.return_value = None
|
||||
mock.hgetall.return_value = {}
|
||||
mock.delete.return_value = 1
|
||||
mock.llen.return_value = 0
|
||||
yield mock
|
||||
|
||||
def test_enqueue(self, mock_redis):
|
||||
"""测试消息入队"""
|
||||
message = MagicMock()
|
||||
message.id = 'msg-123'
|
||||
message.session_id = 'session-123'
|
||||
message.sender_type = 'user'
|
||||
message.sender_id = 'user-123'
|
||||
message.content = 'Hello'
|
||||
message.status = 'sent'
|
||||
message.retry_count = 0
|
||||
message.created_at = None
|
||||
|
||||
result = MessageQueue.enqueue(message)
|
||||
|
||||
assert result is True
|
||||
mock_redis.hset.assert_called()
|
||||
mock_redis.rpush.assert_called()
|
||||
|
||||
def test_dequeue_empty(self, mock_redis):
|
||||
"""测试空队列出队"""
|
||||
mock_redis.lpop.return_value = None
|
||||
|
||||
result = MessageQueue.dequeue()
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_dequeue_success(self, mock_redis):
|
||||
"""测试成功出队"""
|
||||
mock_redis.lpop.return_value = 'msg-123'
|
||||
mock_redis.hgetall.return_value = {
|
||||
'id': 'msg-123',
|
||||
'session_id': 'session-123',
|
||||
'content': 'Hello',
|
||||
'retry_count': '0',
|
||||
}
|
||||
|
||||
result = MessageQueue.dequeue()
|
||||
|
||||
assert result is not None
|
||||
assert result['id'] == 'msg-123'
|
||||
|
||||
def test_ack(self, mock_redis):
|
||||
"""测试消息确认"""
|
||||
result = MessageQueue.ack('msg-123')
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called()
|
||||
|
||||
def test_retry_within_limit(self, mock_redis):
|
||||
"""测试重试(未超限)"""
|
||||
mock_redis.hget.return_value = '1'
|
||||
|
||||
result = MessageQueue.retry('msg-123')
|
||||
|
||||
assert result is True
|
||||
mock_redis.hset.assert_called()
|
||||
mock_redis.rpush.assert_called()
|
||||
|
||||
def test_retry_exceed_limit(self, mock_redis):
|
||||
"""测试重试(超限)"""
|
||||
mock_redis.hget.return_value = '3'
|
||||
|
||||
result = MessageQueue.retry('msg-123')
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_get_pending_count(self, mock_redis):
|
||||
"""测试获取待处理数量"""
|
||||
mock_redis.llen.return_value = 5
|
||||
|
||||
count = MessageQueue.get_pending_count()
|
||||
|
||||
assert count == 5
|
||||
176
tests/test_scheduler.py
Normal file
176
tests/test_scheduler.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
调度器单元测试
|
||||
"""
|
||||
import pytest
|
||||
from app.services.scheduler import (
|
||||
AgentScheduler,
|
||||
RoundRobinScheduler,
|
||||
WeightedRoundRobinScheduler,
|
||||
LeastConnectionsScheduler,
|
||||
LeastResponseTimeScheduler,
|
||||
CapabilityMatchScheduler,
|
||||
)
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
class MockAgent:
|
||||
"""模拟 Agent"""
|
||||
def __init__(self, id, status='online', weight=10, priority=5,
|
||||
current_sessions=0, connection_limit=5, capabilities=None,
|
||||
last_heartbeat=None):
|
||||
self.id = id
|
||||
self.status = status
|
||||
self.weight = weight
|
||||
self.priority = priority
|
||||
self.current_sessions = current_sessions
|
||||
self.connection_limit = connection_limit
|
||||
self.capabilities = capabilities or []
|
||||
self.last_heartbeat = last_heartbeat
|
||||
|
||||
|
||||
class TestRoundRobinScheduler:
|
||||
"""轮询调度器测试"""
|
||||
|
||||
def test_select_agent(self):
|
||||
"""测试选择 Agent"""
|
||||
scheduler = RoundRobinScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1'),
|
||||
MockAgent(id='agent2'),
|
||||
MockAgent(id='agent3'),
|
||||
]
|
||||
|
||||
# 轮询选择
|
||||
agent1 = scheduler.select_agent(agents)
|
||||
agent2 = scheduler.select_agent(agents)
|
||||
agent3 = scheduler.select_agent(agents)
|
||||
agent4 = scheduler.select_agent(agents) # 循环回第一个
|
||||
|
||||
assert agent1.id == 'agent1'
|
||||
assert agent2.id == 'agent2'
|
||||
assert agent3.id == 'agent3'
|
||||
assert agent4.id == 'agent1'
|
||||
|
||||
def test_no_online_agents(self):
|
||||
"""测试无在线 Agent"""
|
||||
scheduler = RoundRobinScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', status='offline'),
|
||||
MockAgent(id='agent2', status='offline'),
|
||||
]
|
||||
|
||||
agent = scheduler.select_agent(agents)
|
||||
assert agent is None
|
||||
|
||||
|
||||
class TestWeightedRoundRobinScheduler:
|
||||
"""加权轮询调度器测试"""
|
||||
|
||||
def test_weight_distribution(self):
|
||||
"""测试权重分布"""
|
||||
scheduler = WeightedRoundRobinScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', weight=3),
|
||||
MockAgent(id='agent2', weight=1),
|
||||
]
|
||||
|
||||
# 统计选择次数
|
||||
counts = {'agent1': 0, 'agent2': 0}
|
||||
for _ in range(100):
|
||||
agent = scheduler.select_agent(agents)
|
||||
counts[agent.id] += 1
|
||||
|
||||
# agent1 权重 3,agent2 权重 1,比例应该约 3:1
|
||||
assert counts['agent1'] > counts['agent2']
|
||||
|
||||
def test_zero_weight(self):
|
||||
"""测试零权重"""
|
||||
scheduler = WeightedRoundRobinScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', weight=0),
|
||||
MockAgent(id='agent2', weight=0),
|
||||
]
|
||||
|
||||
# 权重都为 0 时返回第一个
|
||||
agent = scheduler.select_agent(agents)
|
||||
assert agent is not None
|
||||
|
||||
|
||||
class TestLeastConnectionsScheduler:
|
||||
"""最少连接调度器测试"""
|
||||
|
||||
def test_select_least_connections(self):
|
||||
"""测试选择最少连接"""
|
||||
scheduler = LeastConnectionsScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', current_sessions=5),
|
||||
MockAgent(id='agent2', current_sessions=2),
|
||||
MockAgent(id='agent3', current_sessions=8),
|
||||
]
|
||||
|
||||
agent = scheduler.select_agent(agents)
|
||||
assert agent.id == 'agent2'
|
||||
|
||||
def test_all_at_limit(self):
|
||||
"""测试所有 Agent 都达上限"""
|
||||
scheduler = LeastConnectionsScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', current_sessions=5, connection_limit=5),
|
||||
MockAgent(id='agent2', current_sessions=5, connection_limit=5),
|
||||
]
|
||||
|
||||
agent = scheduler.select_agent(agents)
|
||||
assert agent is None
|
||||
|
||||
|
||||
class TestCapabilityMatchScheduler:
|
||||
"""能力匹配调度器测试"""
|
||||
|
||||
def test_match_capabilities(self):
|
||||
"""测试能力匹配"""
|
||||
scheduler = CapabilityMatchScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', capabilities=['chat', 'code']),
|
||||
MockAgent(id='agent2', capabilities=['chat']),
|
||||
MockAgent(id='agent3', capabilities=['code', 'translate']),
|
||||
]
|
||||
|
||||
# 需要 code 能力
|
||||
agent = scheduler.select_agent(agents, {'capabilities': ['code']})
|
||||
assert agent.id in ['agent1', 'agent3']
|
||||
|
||||
def test_no_matching_capabilities(self):
|
||||
"""测试无匹配能力"""
|
||||
scheduler = CapabilityMatchScheduler()
|
||||
agents = [
|
||||
MockAgent(id='agent1', capabilities=['chat']),
|
||||
]
|
||||
|
||||
# 需要 code 能力但无 Agent 具备
|
||||
agent = scheduler.select_agent(agents, {'capabilities': ['code']})
|
||||
# 回退到加权轮询
|
||||
assert agent is not None
|
||||
|
||||
|
||||
class TestAgentScheduler:
|
||||
"""Agent 调度器工厂测试"""
|
||||
|
||||
def test_available_strategies(self):
|
||||
"""测试可用策略"""
|
||||
strategies = AgentScheduler.get_available_strategies()
|
||||
|
||||
assert 'round_robin' in strategies
|
||||
assert 'weighted_round_robin' in strategies
|
||||
assert 'least_connections' in strategies
|
||||
assert 'least_response_time' in strategies
|
||||
assert 'capability_match' in strategies
|
||||
|
||||
def test_default_strategy(self):
|
||||
"""测试默认策略"""
|
||||
scheduler = AgentScheduler()
|
||||
assert scheduler.get_strategy() == 'weighted_round_robin'
|
||||
|
||||
def test_custom_strategy(self):
|
||||
"""测试自定义策略"""
|
||||
scheduler = AgentScheduler('round_robin')
|
||||
assert scheduler.get_strategy() == 'round_robin'
|
||||
Reference in New Issue
Block a user