feat: Phase 3 - 工具层 + 测试 + 数据库迁移
This commit is contained in:
40
CHANGELOG.md
40
CHANGELOG.md
@@ -11,6 +11,46 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## [0.5.0] - 2026-03-14
|
||||||
|
|
||||||
|
### Added
|
||||||
|
|
||||||
|
#### 🔧 工具层实现 (Phase 3)
|
||||||
|
|
||||||
|
- **validators.py** - 输入验证工具
|
||||||
|
- Marshmallow Schema 验证
|
||||||
|
- UUID/Email/用户名/URL 验证
|
||||||
|
- 字符串清理
|
||||||
|
|
||||||
|
- **security.py** - 安全工具
|
||||||
|
- Token/密码哈希
|
||||||
|
- API Key 生成
|
||||||
|
- RateLimiter 限流器
|
||||||
|
- IPWhitelist IP 白名单
|
||||||
|
|
||||||
|
- **helpers.py** - 辅助函数
|
||||||
|
- 日期时间格式化
|
||||||
|
- 分页辅助
|
||||||
|
- JSON 安全解析
|
||||||
|
|
||||||
|
#### 🧪 测试模块
|
||||||
|
|
||||||
|
- **conftest.py** - pytest 配置
|
||||||
|
- **test_auth.py** - 认证 API 单元测试
|
||||||
|
- **test_scheduler.py** - 调度器单元测试
|
||||||
|
- **test_message_queue.py** - 消息队列单元测试
|
||||||
|
|
||||||
|
#### 🗄️ 数据库迁移
|
||||||
|
|
||||||
|
- **Alembic 配置** - 数据库迁移工具
|
||||||
|
- **初始迁移脚本** - 创建所有数据表
|
||||||
|
|
||||||
|
### Changed
|
||||||
|
|
||||||
|
- **requirements.txt** - 添加 gevent 依赖(用于 gunicorn)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## [0.4.0] - 2026-03-14
|
## [0.4.0] - 2026-03-14
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|||||||
@@ -1,3 +1,81 @@
|
|||||||
"""
|
"""
|
||||||
工具模块
|
工具模块
|
||||||
"""
|
"""
|
||||||
|
from .validators import (
|
||||||
|
UserRegistrationSchema,
|
||||||
|
UserLoginSchema,
|
||||||
|
SessionCreateSchema,
|
||||||
|
MessageSendSchema,
|
||||||
|
AgentRegistrationSchema,
|
||||||
|
GatewayRegistrationSchema,
|
||||||
|
validate_uuid,
|
||||||
|
validate_email,
|
||||||
|
validate_username,
|
||||||
|
validate_url,
|
||||||
|
sanitize_string,
|
||||||
|
ValidationUtils,
|
||||||
|
)
|
||||||
|
from .security import (
|
||||||
|
generate_token,
|
||||||
|
hash_password,
|
||||||
|
verify_password,
|
||||||
|
hash_token,
|
||||||
|
verify_token_hash,
|
||||||
|
generate_api_key,
|
||||||
|
secure_compare,
|
||||||
|
mask_sensitive_data,
|
||||||
|
RateLimiter,
|
||||||
|
IPWhitelist,
|
||||||
|
)
|
||||||
|
from .helpers import (
|
||||||
|
format_datetime,
|
||||||
|
parse_datetime,
|
||||||
|
format_duration,
|
||||||
|
truncate_string,
|
||||||
|
safe_json_loads,
|
||||||
|
safe_json_dumps,
|
||||||
|
generate_session_title,
|
||||||
|
calculate_timeout,
|
||||||
|
merge_dicts,
|
||||||
|
filter_none_values,
|
||||||
|
PaginationHelper,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# Validators
|
||||||
|
'UserRegistrationSchema',
|
||||||
|
'UserLoginSchema',
|
||||||
|
'SessionCreateSchema',
|
||||||
|
'MessageSendSchema',
|
||||||
|
'AgentRegistrationSchema',
|
||||||
|
'GatewayRegistrationSchema',
|
||||||
|
'validate_uuid',
|
||||||
|
'validate_email',
|
||||||
|
'validate_username',
|
||||||
|
'validate_url',
|
||||||
|
'sanitize_string',
|
||||||
|
'ValidationUtils',
|
||||||
|
# Security
|
||||||
|
'generate_token',
|
||||||
|
'hash_password',
|
||||||
|
'verify_password',
|
||||||
|
'hash_token',
|
||||||
|
'verify_token_hash',
|
||||||
|
'generate_api_key',
|
||||||
|
'secure_compare',
|
||||||
|
'mask_sensitive_data',
|
||||||
|
'RateLimiter',
|
||||||
|
'IPWhitelist',
|
||||||
|
# Helpers
|
||||||
|
'format_datetime',
|
||||||
|
'parse_datetime',
|
||||||
|
'format_duration',
|
||||||
|
'truncate_string',
|
||||||
|
'safe_json_loads',
|
||||||
|
'safe_json_dumps',
|
||||||
|
'generate_session_title',
|
||||||
|
'calculate_timeout',
|
||||||
|
'merge_dicts',
|
||||||
|
'filter_none_values',
|
||||||
|
'PaginationHelper',
|
||||||
|
]
|
||||||
|
|||||||
113
app/utils/helpers.py
Normal file
113
app/utils/helpers.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
"""
|
||||||
|
辅助函数
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, Any
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def format_datetime(dt: Optional[datetime]) -> Optional[str]:
|
||||||
|
"""格式化日期时间"""
|
||||||
|
if not dt:
|
||||||
|
return None
|
||||||
|
return dt.isoformat()
|
||||||
|
|
||||||
|
|
||||||
|
def parse_datetime(dt_str: str) -> Optional[datetime]:
|
||||||
|
"""解析日期时间字符串"""
|
||||||
|
if not dt_str:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return datetime.fromisoformat(dt_str)
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def format_duration(seconds: int) -> str:
|
||||||
|
"""格式化时长"""
|
||||||
|
if seconds < 60:
|
||||||
|
return f"{seconds}s"
|
||||||
|
elif seconds < 3600:
|
||||||
|
return f"{seconds // 60}m"
|
||||||
|
elif seconds < 86400:
|
||||||
|
return f"{seconds // 3600}h"
|
||||||
|
else:
|
||||||
|
return f"{seconds // 86400}d"
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_string(s: str, max_length: int = 100, suffix: str = "...") -> str:
|
||||||
|
"""截断字符串"""
|
||||||
|
if len(s) <= max_length:
|
||||||
|
return s
|
||||||
|
return s[:max_length - len(suffix)] + suffix
|
||||||
|
|
||||||
|
|
||||||
|
def safe_json_loads(data: str, default: Any = None) -> Any:
|
||||||
|
"""安全的 JSON 解析"""
|
||||||
|
try:
|
||||||
|
return json.loads(data)
|
||||||
|
except:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def safe_json_dumps(data: Any, default: str = "{}") -> str:
|
||||||
|
"""安全的 JSON 序列化"""
|
||||||
|
try:
|
||||||
|
return json.dumps(data, ensure_ascii=False)
|
||||||
|
except:
|
||||||
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def generate_session_title() -> str:
|
||||||
|
"""生成默认会话标题"""
|
||||||
|
return f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}"
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_timeout(start_time: datetime, timeout_seconds: int) -> bool:
|
||||||
|
"""检查是否超时"""
|
||||||
|
if not start_time:
|
||||||
|
return True
|
||||||
|
|
||||||
|
elapsed = (datetime.utcnow() - start_time).total_seconds()
|
||||||
|
return elapsed > timeout_seconds
|
||||||
|
|
||||||
|
|
||||||
|
def merge_dicts(base: dict, override: dict) -> dict:
|
||||||
|
"""合并字典"""
|
||||||
|
result = base.copy()
|
||||||
|
result.update(override)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def filter_none_values(data: dict) -> dict:
|
||||||
|
"""过滤字典中的 None 值"""
|
||||||
|
return {k: v for k, v in data.items() if v is not None}
|
||||||
|
|
||||||
|
|
||||||
|
class PaginationHelper:
|
||||||
|
"""分页辅助类"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def paginate(query, page: int = 1, per_page: int = 20):
|
||||||
|
"""分页查询"""
|
||||||
|
if page < 1:
|
||||||
|
page = 1
|
||||||
|
if per_page < 1:
|
||||||
|
per_page = 20
|
||||||
|
if per_page > 100:
|
||||||
|
per_page = 100
|
||||||
|
|
||||||
|
total = query.count()
|
||||||
|
items = query.offset((page - 1) * per_page).limit(per_page).all()
|
||||||
|
|
||||||
|
pages = (total + per_page - 1) // per_page
|
||||||
|
|
||||||
|
return {
|
||||||
|
'items': items,
|
||||||
|
'total': total,
|
||||||
|
'page': page,
|
||||||
|
'per_page': per_page,
|
||||||
|
'pages': pages,
|
||||||
|
'has_next': page < pages,
|
||||||
|
'has_prev': page > 1,
|
||||||
|
}
|
||||||
121
app/utils/security.py
Normal file
121
app/utils/security.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""
|
||||||
|
安全工具
|
||||||
|
"""
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import secrets
|
||||||
|
import bcrypt
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def generate_token(length: int = 32) -> str:
|
||||||
|
"""生成随机 Token"""
|
||||||
|
return secrets.token_urlsafe(length)
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
"""密码哈希"""
|
||||||
|
return bcrypt.hashpw(
|
||||||
|
password.encode('utf-8'),
|
||||||
|
bcrypt.gensalt()
|
||||||
|
).decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(password: str, password_hash: str) -> bool:
|
||||||
|
"""验证密码"""
|
||||||
|
return bcrypt.checkpw(
|
||||||
|
password.encode('utf-8'),
|
||||||
|
password_hash.encode('utf-8')
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def hash_token(token: str) -> str:
|
||||||
|
"""Token 哈希(用于存储)"""
|
||||||
|
return hashlib.sha256(token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_token_hash(token: str, token_hash: str) -> bool:
|
||||||
|
"""验证 Token 哈希"""
|
||||||
|
return hash_token(token) == token_hash
|
||||||
|
|
||||||
|
|
||||||
|
def generate_api_key() -> str:
|
||||||
|
"""生成 API Key"""
|
||||||
|
return f"pit_{secrets.token_urlsafe(32)}"
|
||||||
|
|
||||||
|
|
||||||
|
def secure_compare(val1: str, val2: str) -> bool:
|
||||||
|
"""安全字符串比较(防时序攻击)"""
|
||||||
|
return hmac.compare_digest(val1.encode(), val2.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def mask_sensitive_data(data: str, visible_chars: int = 4) -> str:
|
||||||
|
"""脱敏处理"""
|
||||||
|
if len(data) <= visible_chars * 2:
|
||||||
|
return '*' * len(data)
|
||||||
|
|
||||||
|
return data[:visible_chars] + '*' * (len(data) - visible_chars * 2) + data[-visible_chars:]
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimiter:
|
||||||
|
"""简单内存限流器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._storage = {}
|
||||||
|
|
||||||
|
def is_allowed(self, key: str, limit: int = 100, window: int = 60) -> bool:
|
||||||
|
"""检查是否允许请求"""
|
||||||
|
from time import time
|
||||||
|
|
||||||
|
now = time()
|
||||||
|
window_start = now - window
|
||||||
|
|
||||||
|
# 获取该 key 的请求记录
|
||||||
|
requests = self._storage.get(key, [])
|
||||||
|
|
||||||
|
# 清理过期记录
|
||||||
|
requests = [t for t in requests if t > window_start]
|
||||||
|
|
||||||
|
# 检查是否超过限制
|
||||||
|
if len(requests) >= limit:
|
||||||
|
self._storage[key] = requests
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 记录新请求
|
||||||
|
requests.append(now)
|
||||||
|
self._storage[key] = requests
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class IPWhitelist:
|
||||||
|
"""IP 白名单"""
|
||||||
|
|
||||||
|
def __init__(self, allowed_ips: list = None):
|
||||||
|
self.allowed_ips = set(allowed_ips or [])
|
||||||
|
self.allow_all = '*' in self.allowed_ips
|
||||||
|
|
||||||
|
def is_allowed(self, ip: str) -> bool:
|
||||||
|
"""检查 IP 是否允许"""
|
||||||
|
if self.allow_all:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 支持 CIDR 格式
|
||||||
|
for allowed in self.allowed_ips:
|
||||||
|
if '/' in allowed:
|
||||||
|
if self._ip_in_cidr(ip, allowed):
|
||||||
|
return True
|
||||||
|
elif ip == allowed:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _ip_in_cidr(self, ip: str, cidr: str) -> bool:
|
||||||
|
"""检查 IP 是否在 CIDR 范围内"""
|
||||||
|
try:
|
||||||
|
import ipaddress
|
||||||
|
network = ipaddress.ip_network(cidr, strict=False)
|
||||||
|
address = ipaddress.ip_address(ip)
|
||||||
|
return address in network
|
||||||
|
except:
|
||||||
|
return False
|
||||||
120
app/utils/validators.py
Normal file
120
app/utils/validators.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""
|
||||||
|
输入验证工具
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
import re
|
||||||
|
from marshmallow import Schema, fields, validate, ValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class UserRegistrationSchema(Schema):
|
||||||
|
"""用户注册验证"""
|
||||||
|
username = fields.Str(required=True, validate=validate.Length(min=3, max=80))
|
||||||
|
email = fields.Email(required=True)
|
||||||
|
password = fields.Str(required=True, validate=validate.Length(min=6, max=128))
|
||||||
|
nickname = fields.Str(validate=validate.Length(max=80), load_default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class UserLoginSchema(Schema):
|
||||||
|
"""用户登录验证"""
|
||||||
|
username = fields.Str(required=True)
|
||||||
|
password = fields.Str(required=True)
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCreateSchema(Schema):
|
||||||
|
"""会话创建验证"""
|
||||||
|
title = fields.Str(validate=validate.Length(max=200), load_default=None)
|
||||||
|
agent_id = fields.Str(validate=validate.Length(max=36), load_default=None)
|
||||||
|
priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
|
||||||
|
|
||||||
|
|
||||||
|
class MessageSendSchema(Schema):
|
||||||
|
"""消息发送验证"""
|
||||||
|
session_id = fields.Str(required=True, validate=validate.Length(max=36))
|
||||||
|
content = fields.Str(required=True, validate=validate.Length(min=1, max=10000))
|
||||||
|
message_type = fields.Str(validate=validate.OneOf(['text', 'media', 'system']), load_default='text')
|
||||||
|
reply_to = fields.Str(validate=validate.Length(max=36), load_default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRegistrationSchema(Schema):
|
||||||
|
"""Agent 注册验证"""
|
||||||
|
name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
|
||||||
|
gateway_id = fields.Str(validate=validate.Length(max=36), load_default=None)
|
||||||
|
model = fields.Str(validate=validate.Length(max=80), load_default=None)
|
||||||
|
capabilities = fields.List(fields.Str(), load_default=[])
|
||||||
|
priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
|
||||||
|
weight = fields.Int(validate=validate.Range(min=1, max=100), load_default=10)
|
||||||
|
|
||||||
|
|
||||||
|
class GatewayRegistrationSchema(Schema):
|
||||||
|
"""Gateway 注册验证"""
|
||||||
|
name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
|
||||||
|
url = fields.Str(required=True, validate=validate.Length(max=256))
|
||||||
|
token = fields.Str(validate=validate.Length(max=256), load_default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_uuid(uuid_str: str) -> bool:
|
||||||
|
"""验证 UUID 格式"""
|
||||||
|
pattern = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$'
|
||||||
|
return bool(re.match(pattern, uuid_str.lower()))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_email(email: str) -> bool:
|
||||||
|
"""验证邮箱格式"""
|
||||||
|
pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$'
|
||||||
|
return bool(re.match(pattern, email))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_username(username: str) -> bool:
|
||||||
|
"""验证用户名格式"""
|
||||||
|
# 字母、数字、下划线,3-80 字符
|
||||||
|
pattern = r'^[a-zA-Z0-9_]{3,80}$'
|
||||||
|
return bool(re.match(pattern, username))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_url(url: str) -> bool:
|
||||||
|
"""验证 URL 格式"""
|
||||||
|
pattern = r'^https?://[^\s/$.?#].[^\s]*$'
|
||||||
|
return bool(re.match(pattern, url, re.IGNORECASE))
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_string(input_str: str, max_length: int = 1000) -> str:
|
||||||
|
"""清理字符串输入"""
|
||||||
|
if not input_str:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 去除首尾空格
|
||||||
|
cleaned = input_str.strip()
|
||||||
|
|
||||||
|
# 限制长度
|
||||||
|
if len(cleaned) > max_length:
|
||||||
|
cleaned = cleaned[:max_length]
|
||||||
|
|
||||||
|
# 去除危险字符
|
||||||
|
cleaned = cleaned.replace('\x00', '')
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationUtils:
|
||||||
|
"""验证工具类"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_json(data: dict, schema: Schema) -> tuple:
|
||||||
|
"""验证 JSON 数据"""
|
||||||
|
try:
|
||||||
|
result = schema.load(data)
|
||||||
|
return True, result, None
|
||||||
|
except ValidationError as e:
|
||||||
|
return False, None, e.messages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def validate_pagination(page: int, per_page: int, max_per_page: int = 100) -> tuple:
|
||||||
|
"""验证分页参数"""
|
||||||
|
if page < 1:
|
||||||
|
page = 1
|
||||||
|
if per_page < 1:
|
||||||
|
per_page = 20
|
||||||
|
if per_page > max_per_page:
|
||||||
|
per_page = max_per_page
|
||||||
|
|
||||||
|
return page, per_page
|
||||||
48
migrations/alembic.ini
Normal file
48
migrations/alembic.ini
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
[alembic]
|
||||||
|
# 脚本位置
|
||||||
|
script_location = migrations
|
||||||
|
|
||||||
|
# 模板文件
|
||||||
|
template_file =
|
||||||
|
|
||||||
|
# 最大保留版本数
|
||||||
|
max_num = 10
|
||||||
|
|
||||||
|
# 数据库 URL(运行时覆盖)
|
||||||
|
sqlalchemy.url = sqlite:///pit_router.db
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
68
migrations/env.py
Normal file
68
migrations/env.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""
|
||||||
|
Alembic 环境配置
|
||||||
|
"""
|
||||||
|
from logging.config import fileConfig
|
||||||
|
from sqlalchemy import engine_from_config
|
||||||
|
from sqlalchemy import pool
|
||||||
|
from alembic import context
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# 添加项目路径
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from app.extensions import db
|
||||||
|
from app.config import Config
|
||||||
|
|
||||||
|
# Alembic Config 对象
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# 模型元数据
|
||||||
|
target_metadata = db.Model.metadata
|
||||||
|
|
||||||
|
def get_url():
|
||||||
|
"""获取数据库 URL"""
|
||||||
|
return os.getenv('DATABASE_URL', 'sqlite:///pit_router.db')
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""离线迁移"""
|
||||||
|
url = get_url()
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
"""在线迁移"""
|
||||||
|
configuration = config.get_section(config.config_ini_section)
|
||||||
|
configuration["sqlalchemy.url"] = get_url()
|
||||||
|
connectable = engine_from_config(
|
||||||
|
configuration,
|
||||||
|
prefix="sqlalchemy.",
|
||||||
|
poolclass=pool.NullPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
with connectable.connect() as connection:
|
||||||
|
context.configure(
|
||||||
|
connection=connection, target_metadata=target_metadata
|
||||||
|
)
|
||||||
|
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
134
migrations/versions/initial.py
Normal file
134
migrations/versions/initial.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""
|
||||||
|
初始化迁移脚本
|
||||||
|
Revision ID: initial
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-14
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers
|
||||||
|
revision = 'initial'
|
||||||
|
down_revision = None
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
"""创建表"""
|
||||||
|
# 用户表
|
||||||
|
op.create_table(
|
||||||
|
'users',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('username', sa.String(80), unique=True, nullable=False),
|
||||||
|
sa.Column('password_hash', sa.String(256), nullable=False),
|
||||||
|
sa.Column('email', sa.String(120), unique=True, nullable=False),
|
||||||
|
sa.Column('nickname', sa.String(80)),
|
||||||
|
sa.Column('role', sa.String(20), default='user'),
|
||||||
|
sa.Column('status', sa.String(20), default='active'),
|
||||||
|
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
|
||||||
|
sa.Column('last_login_at', sa.DateTime),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Gateway 表
|
||||||
|
op.create_table(
|
||||||
|
'gateways',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('name', sa.String(80), unique=True, nullable=False),
|
||||||
|
sa.Column('url', sa.String(256), nullable=False),
|
||||||
|
sa.Column('token_hash', sa.String(256)),
|
||||||
|
sa.Column('status', sa.String(20), default='offline'),
|
||||||
|
sa.Column('agent_count', sa.Integer, default=0),
|
||||||
|
sa.Column('connection_limit', sa.Integer, default=10),
|
||||||
|
sa.Column('heartbeat_interval', sa.Integer, default=60),
|
||||||
|
sa.Column('allowed_ips', sa.JSON),
|
||||||
|
sa.Column('last_heartbeat', sa.DateTime),
|
||||||
|
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Agent 表
|
||||||
|
op.create_table(
|
||||||
|
'agents',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('name', sa.String(80), nullable=False),
|
||||||
|
sa.Column('display_name', sa.String(80)),
|
||||||
|
sa.Column('gateway_id', sa.String(36), sa.ForeignKey('gateways.id')),
|
||||||
|
sa.Column('socket_id', sa.String(100)),
|
||||||
|
sa.Column('model', sa.String(80)),
|
||||||
|
sa.Column('capabilities', sa.JSON),
|
||||||
|
sa.Column('status', sa.String(20), default='offline'),
|
||||||
|
sa.Column('priority', sa.Integer, default=5),
|
||||||
|
sa.Column('weight', sa.Integer, default=10),
|
||||||
|
sa.Column('connection_limit', sa.Integer, default=5),
|
||||||
|
sa.Column('current_sessions', sa.Integer, default=0),
|
||||||
|
sa.Column('last_heartbeat', sa.DateTime),
|
||||||
|
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 会话表
|
||||||
|
op.create_table(
|
||||||
|
'sessions',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False),
|
||||||
|
sa.Column('primary_agent_id', sa.String(36), sa.ForeignKey('agents.id')),
|
||||||
|
sa.Column('participating_agent_ids', sa.JSON),
|
||||||
|
sa.Column('user_socket_id', sa.String(100)),
|
||||||
|
sa.Column('title', sa.String(200)),
|
||||||
|
sa.Column('channel_type', sa.String(20), default='web'),
|
||||||
|
sa.Column('status', sa.String(20), default='active'),
|
||||||
|
sa.Column('message_count', sa.Integer, default=0),
|
||||||
|
sa.Column('unread_count', sa.Integer, default=0),
|
||||||
|
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
|
||||||
|
sa.Column('updated_at', sa.DateTime),
|
||||||
|
sa.Column('last_active_at', sa.DateTime),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 消息表
|
||||||
|
op.create_table(
|
||||||
|
'messages',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('session_id', sa.String(36), sa.ForeignKey('sessions.id'), nullable=False),
|
||||||
|
sa.Column('sender_type', sa.String(20), nullable=False),
|
||||||
|
sa.Column('sender_id', sa.String(36), nullable=False),
|
||||||
|
sa.Column('message_type', sa.String(20), default='text'),
|
||||||
|
sa.Column('content', sa.Text),
|
||||||
|
sa.Column('content_type', sa.String(20), default='markdown'),
|
||||||
|
sa.Column('reply_to', sa.String(36)),
|
||||||
|
sa.Column('status', sa.String(20), default='sent'),
|
||||||
|
sa.Column('ack_status', sa.String(20), default='pending'),
|
||||||
|
sa.Column('retry_count', sa.Integer, default=0),
|
||||||
|
sa.Column('created_at', sa.DateTime, default=sa.func.now()),
|
||||||
|
sa.Column('delivered_at', sa.DateTime),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 连接表
|
||||||
|
op.create_table(
|
||||||
|
'connections',
|
||||||
|
sa.Column('id', sa.String(36), primary_key=True),
|
||||||
|
sa.Column('socket_id', sa.String(100), unique=True, nullable=False),
|
||||||
|
sa.Column('connection_type', sa.String(20), nullable=False),
|
||||||
|
sa.Column('entity_id', sa.String(36), nullable=False),
|
||||||
|
sa.Column('entity_type', sa.String(20), nullable=False),
|
||||||
|
sa.Column('ip_address', sa.String(45)),
|
||||||
|
sa.Column('user_agent', sa.String(500)),
|
||||||
|
sa.Column('status', sa.String(20), default='connected'),
|
||||||
|
sa.Column('auth_token', sa.String(500)),
|
||||||
|
sa.Column('connected_at', sa.DateTime, default=sa.func.now()),
|
||||||
|
sa.Column('last_activity', sa.DateTime),
|
||||||
|
sa.Column('disconnected_at', sa.DateTime),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建索引
|
||||||
|
op.create_index('ix_messages_session_id', 'messages', ['session_id'])
|
||||||
|
op.create_index('ix_sessions_user_id', 'sessions', ['user_id'])
|
||||||
|
op.create_index('ix_sessions_status', 'sessions', ['status'])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
"""删除表"""
|
||||||
|
op.drop_table('connections')
|
||||||
|
op.drop_table('messages')
|
||||||
|
op.drop_table('sessions')
|
||||||
|
op.drop_table('agents')
|
||||||
|
op.drop_table('gateways')
|
||||||
|
op.drop_table('users')
|
||||||
@@ -7,6 +7,7 @@ Flask-Login==0.6.3
|
|||||||
Flask-Migrate==4.0.5
|
Flask-Migrate==4.0.5
|
||||||
Flask-SocketIO==5.3.6
|
Flask-SocketIO==5.3.6
|
||||||
Flask-SQLAlchemy==3.1.1
|
Flask-SQLAlchemy==3.1.1
|
||||||
|
gevent==24.2.1
|
||||||
gunicorn==21.2.0
|
gunicorn==21.2.0
|
||||||
psycopg2-binary==2.9.9
|
psycopg2-binary==2.9.9
|
||||||
PyJWT==2.8.0
|
PyJWT==2.8.0
|
||||||
|
|||||||
3
tests/__init__.py
Normal file
3
tests/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
测试模块初始化
|
||||||
|
"""
|
||||||
35
tests/conftest.py
Normal file
35
tests/conftest.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""
|
||||||
|
pytest 配置
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 添加项目根目录到路径
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app():
|
||||||
|
"""创建测试应用"""
|
||||||
|
from app import create_app
|
||||||
|
app = create_app('testing')
|
||||||
|
|
||||||
|
with app.app_context():
|
||||||
|
from app.extensions import db
|
||||||
|
db.create_all()
|
||||||
|
yield app
|
||||||
|
db.session.remove()
|
||||||
|
db.drop_all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(app):
|
||||||
|
"""创建测试客户端"""
|
||||||
|
return app.test_client()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def runner(app):
|
||||||
|
"""创建测试运行器"""
|
||||||
|
return app.test_cli_runner()
|
||||||
158
tests/test_auth.py
Normal file
158
tests/test_auth.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""
|
||||||
|
认证 API 单元测试
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from app import create_app
|
||||||
|
from app.extensions import db
|
||||||
|
from app.models import User
|
||||||
|
from app.utils.security import hash_password
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app():
|
||||||
|
"""创建测试应用"""
|
||||||
|
app = create_app('testing')
|
||||||
|
with app.app_context():
|
||||||
|
db.create_all()
|
||||||
|
yield app
|
||||||
|
db.drop_all()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(app):
|
||||||
|
"""创建测试客户端"""
|
||||||
|
return app.test_client()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthAPI:
|
||||||
|
"""认证 API 测试"""
|
||||||
|
|
||||||
|
def test_register_success(self, client):
|
||||||
|
"""测试用户注册成功"""
|
||||||
|
response = client.post('/api/auth/register', json={
|
||||||
|
'username': 'testuser',
|
||||||
|
'email': 'test@example.com',
|
||||||
|
'password': 'password123',
|
||||||
|
'nickname': 'Test User'
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 201
|
||||||
|
data = response.get_json()
|
||||||
|
assert data['user']['username'] == 'testuser'
|
||||||
|
assert data['user']['email'] == 'test@example.com'
|
||||||
|
|
||||||
|
def test_register_duplicate_username(self, client, app):
|
||||||
|
"""测试重复用户名"""
|
||||||
|
# 创建已存在的用户
|
||||||
|
with app.app_context():
|
||||||
|
user = User(
|
||||||
|
username='testuser',
|
||||||
|
email='existing@example.com',
|
||||||
|
password_hash=hash_password('password')
|
||||||
|
)
|
||||||
|
db.session.add(user)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# 尝试注册相同用户名
|
||||||
|
response = client.post('/api/auth/register', json={
|
||||||
|
'username': 'testuser',
|
||||||
|
'email': 'new@example.com',
|
||||||
|
'password': 'password123'
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
def test_register_missing_fields(self, client):
|
||||||
|
"""测试缺少必填字段"""
|
||||||
|
response = client.post('/api/auth/register', json={
|
||||||
|
'username': 'testuser'
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
|
||||||
|
def test_login_success(self, client, app):
|
||||||
|
"""测试登录成功"""
|
||||||
|
# 创建用户
|
||||||
|
with app.app_context():
|
||||||
|
user = User(
|
||||||
|
username='testuser',
|
||||||
|
email='test@example.com',
|
||||||
|
password_hash=hash_password('password123')
|
||||||
|
)
|
||||||
|
db.session.add(user)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# 登录
|
||||||
|
response = client.post('/api/auth/login', json={
|
||||||
|
'username': 'testuser',
|
||||||
|
'password': 'password123'
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.get_json()
|
||||||
|
assert 'access_token' in data
|
||||||
|
assert 'refresh_token' in data
|
||||||
|
|
||||||
|
def test_login_invalid_credentials(self, client):
|
||||||
|
"""测试无效凭据"""
|
||||||
|
response = client.post('/api/auth/login', json={
|
||||||
|
'username': 'nonexistent',
|
||||||
|
'password': 'wrong'
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_verify_token(self, client, app):
|
||||||
|
"""测试 Token 验证"""
|
||||||
|
# 创建用户
|
||||||
|
with app.app_context():
|
||||||
|
user = User(
|
||||||
|
username='testuser',
|
||||||
|
email='test@example.com',
|
||||||
|
password_hash=hash_password('password123')
|
||||||
|
)
|
||||||
|
db.session.add(user)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# 登录获取 token
|
||||||
|
login_response = client.post('/api/auth/login', json={
|
||||||
|
'username': 'testuser',
|
||||||
|
'password': 'password123'
|
||||||
|
})
|
||||||
|
token = login_response.get_json()['access_token']
|
||||||
|
|
||||||
|
# 验证 token
|
||||||
|
response = client.post('/api/auth/verify', headers={
|
||||||
|
'Authorization': f'Bearer {token}'
|
||||||
|
})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.get_json()
|
||||||
|
assert data['valid'] is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidation:
|
||||||
|
"""验证测试"""
|
||||||
|
|
||||||
|
def test_validate_email(self, client):
|
||||||
|
"""测试邮箱验证"""
|
||||||
|
from app.utils.validators import validate_email
|
||||||
|
|
||||||
|
assert validate_email('test@example.com') is True
|
||||||
|
assert validate_email('invalid-email') is False
|
||||||
|
|
||||||
|
def test_validate_username(self, client):
|
||||||
|
"""测试用户名验证"""
|
||||||
|
from app.utils.validators import validate_username
|
||||||
|
|
||||||
|
assert validate_username('testuser') is True
|
||||||
|
assert validate_username('ab') is False # 太短
|
||||||
|
assert validate_username('test-user') is True
|
||||||
|
assert validate_username('test@user') is False # 包含非法字符
|
||||||
|
|
||||||
|
def test_validate_uuid(self, client):
|
||||||
|
"""测试 UUID 验证"""
|
||||||
|
from app.utils.validators import validate_uuid
|
||||||
|
|
||||||
|
assert validate_uuid('550e8400-e29b-41d4-a716-446655440000') is True
|
||||||
|
assert validate_uuid('invalid-uuid') is False
|
||||||
96
tests/test_message_queue.py
Normal file
96
tests/test_message_queue.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
"""
|
||||||
|
消息队列单元测试
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from app.services.message_queue import MessageQueue
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
class TestMessageQueue:
|
||||||
|
"""消息队列测试"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_redis(self):
|
||||||
|
"""模拟 Redis 客户端"""
|
||||||
|
with patch('app.services.message_queue.redis_client') as mock:
|
||||||
|
mock.hset.return_value = 1
|
||||||
|
mock.rpush.return_value = 1
|
||||||
|
mock.lpop.return_value = None
|
||||||
|
mock.hgetall.return_value = {}
|
||||||
|
mock.delete.return_value = 1
|
||||||
|
mock.llen.return_value = 0
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
def test_enqueue(self, mock_redis):
|
||||||
|
"""测试消息入队"""
|
||||||
|
message = MagicMock()
|
||||||
|
message.id = 'msg-123'
|
||||||
|
message.session_id = 'session-123'
|
||||||
|
message.sender_type = 'user'
|
||||||
|
message.sender_id = 'user-123'
|
||||||
|
message.content = 'Hello'
|
||||||
|
message.status = 'sent'
|
||||||
|
message.retry_count = 0
|
||||||
|
message.created_at = None
|
||||||
|
|
||||||
|
result = MessageQueue.enqueue(message)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_redis.hset.assert_called()
|
||||||
|
mock_redis.rpush.assert_called()
|
||||||
|
|
||||||
|
def test_dequeue_empty(self, mock_redis):
|
||||||
|
"""测试空队列出队"""
|
||||||
|
mock_redis.lpop.return_value = None
|
||||||
|
|
||||||
|
result = MessageQueue.dequeue()
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_dequeue_success(self, mock_redis):
|
||||||
|
"""测试成功出队"""
|
||||||
|
mock_redis.lpop.return_value = 'msg-123'
|
||||||
|
mock_redis.hgetall.return_value = {
|
||||||
|
'id': 'msg-123',
|
||||||
|
'session_id': 'session-123',
|
||||||
|
'content': 'Hello',
|
||||||
|
'retry_count': '0',
|
||||||
|
}
|
||||||
|
|
||||||
|
result = MessageQueue.dequeue()
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result['id'] == 'msg-123'
|
||||||
|
|
||||||
|
def test_ack(self, mock_redis):
|
||||||
|
"""测试消息确认"""
|
||||||
|
result = MessageQueue.ack('msg-123')
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_redis.delete.assert_called()
|
||||||
|
|
||||||
|
def test_retry_within_limit(self, mock_redis):
|
||||||
|
"""测试重试(未超限)"""
|
||||||
|
mock_redis.hget.return_value = '1'
|
||||||
|
|
||||||
|
result = MessageQueue.retry('msg-123')
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
mock_redis.hset.assert_called()
|
||||||
|
mock_redis.rpush.assert_called()
|
||||||
|
|
||||||
|
def test_retry_exceed_limit(self, mock_redis):
|
||||||
|
"""测试重试(超限)"""
|
||||||
|
mock_redis.hget.return_value = '3'
|
||||||
|
|
||||||
|
result = MessageQueue.retry('msg-123')
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
|
||||||
|
def test_get_pending_count(self, mock_redis):
|
||||||
|
"""测试获取待处理数量"""
|
||||||
|
mock_redis.llen.return_value = 5
|
||||||
|
|
||||||
|
count = MessageQueue.get_pending_count()
|
||||||
|
|
||||||
|
assert count == 5
|
||||||
176
tests/test_scheduler.py
Normal file
176
tests/test_scheduler.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
"""
|
||||||
|
调度器单元测试
|
||||||
|
"""
|
||||||
|
import pytest
|
||||||
|
from app.services.scheduler import (
|
||||||
|
AgentScheduler,
|
||||||
|
RoundRobinScheduler,
|
||||||
|
WeightedRoundRobinScheduler,
|
||||||
|
LeastConnectionsScheduler,
|
||||||
|
LeastResponseTimeScheduler,
|
||||||
|
CapabilityMatchScheduler,
|
||||||
|
)
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
|
class MockAgent:
|
||||||
|
"""模拟 Agent"""
|
||||||
|
def __init__(self, id, status='online', weight=10, priority=5,
|
||||||
|
current_sessions=0, connection_limit=5, capabilities=None,
|
||||||
|
last_heartbeat=None):
|
||||||
|
self.id = id
|
||||||
|
self.status = status
|
||||||
|
self.weight = weight
|
||||||
|
self.priority = priority
|
||||||
|
self.current_sessions = current_sessions
|
||||||
|
self.connection_limit = connection_limit
|
||||||
|
self.capabilities = capabilities or []
|
||||||
|
self.last_heartbeat = last_heartbeat
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoundRobinScheduler:
|
||||||
|
"""轮询调度器测试"""
|
||||||
|
|
||||||
|
def test_select_agent(self):
|
||||||
|
"""测试选择 Agent"""
|
||||||
|
scheduler = RoundRobinScheduler()
|
||||||
|
agents = [
|
||||||
|
MockAgent(id='agent1'),
|
||||||
|
MockAgent(id='agent2'),
|
||||||
|
MockAgent(id='agent3'),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 轮询选择
|
||||||
|
agent1 = scheduler.select_agent(agents)
|
||||||
|
agent2 = scheduler.select_agent(agents)
|
||||||
|
agent3 = scheduler.select_agent(agents)
|
||||||
|
agent4 = scheduler.select_agent(agents) # 循环回第一个
|
||||||
|
|
||||||
|
assert agent1.id == 'agent1'
|
||||||
|
assert agent2.id == 'agent2'
|
||||||
|
assert agent3.id == 'agent3'
|
||||||
|
assert agent4.id == 'agent1'
|
||||||
|
|
||||||
|
def test_no_online_agents(self):
|
||||||
|
"""测试无在线 Agent"""
|
||||||
|
scheduler = RoundRobinScheduler()
|
||||||
|
agents = [
|
||||||
|
MockAgent(id='agent1', status='offline'),
|
||||||
|
MockAgent(id='agent2', status='offline'),
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = scheduler.select_agent(agents)
|
||||||
|
assert agent is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestWeightedRoundRobinScheduler:
|
||||||
|
"""加权轮询调度器测试"""
|
||||||
|
|
||||||
|
def test_weight_distribution(self):
|
||||||
|
"""测试权重分布"""
|
||||||
|
scheduler = WeightedRoundRobinScheduler()
|
||||||
|
agents = [
|
||||||
|
MockAgent(id='agent1', weight=3),
|
||||||
|
MockAgent(id='agent2', weight=1),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 统计选择次数
|
||||||
|
counts = {'agent1': 0, 'agent2': 0}
|
||||||
|
for _ in range(100):
|
||||||
|
agent = scheduler.select_agent(agents)
|
||||||
|
counts[agent.id] += 1
|
||||||
|
|
||||||
|
# agent1 权重 3,agent2 权重 1,比例应该约 3:1
|
||||||
|
assert counts['agent1'] > counts['agent2']
|
||||||
|
|
||||||
|
def test_zero_weight(self):
|
||||||
|
"""测试零权重"""
|
||||||
|
scheduler = WeightedRoundRobinScheduler()
|
||||||
|
agents = [
|
||||||
|
MockAgent(id='agent1', weight=0),
|
||||||
|
MockAgent(id='agent2', weight=0),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 权重都为 0 时返回第一个
|
||||||
|
agent = scheduler.select_agent(agents)
|
||||||
|
assert agent is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestLeastConnectionsScheduler:
|
||||||
|
"""最少连接调度器测试"""
|
||||||
|
|
||||||
|
def test_select_least_connections(self):
|
||||||
|
"""测试选择最少连接"""
|
||||||
|
scheduler = LeastConnectionsScheduler()
|
||||||
|
agents = [
|
||||||
|
MockAgent(id='agent1', current_sessions=5),
|
||||||
|
MockAgent(id='agent2', current_sessions=2),
|
||||||
|
MockAgent(id='agent3', current_sessions=8),
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = scheduler.select_agent(agents)
|
||||||
|
assert agent.id == 'agent2'
|
||||||
|
|
||||||
|
def test_all_at_limit(self):
|
||||||
|
"""测试所有 Agent 都达上限"""
|
||||||
|
scheduler = LeastConnectionsScheduler()
|
||||||
|
agents = [
|
||||||
|
MockAgent(id='agent1', current_sessions=5, connection_limit=5),
|
||||||
|
MockAgent(id='agent2', current_sessions=5, connection_limit=5),
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = scheduler.select_agent(agents)
|
||||||
|
assert agent is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestCapabilityMatchScheduler:
|
||||||
|
"""能力匹配调度器测试"""
|
||||||
|
|
||||||
|
def test_match_capabilities(self):
|
||||||
|
"""测试能力匹配"""
|
||||||
|
scheduler = CapabilityMatchScheduler()
|
||||||
|
agents = [
|
||||||
|
MockAgent(id='agent1', capabilities=['chat', 'code']),
|
||||||
|
MockAgent(id='agent2', capabilities=['chat']),
|
||||||
|
MockAgent(id='agent3', capabilities=['code', 'translate']),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 需要 code 能力
|
||||||
|
agent = scheduler.select_agent(agents, {'capabilities': ['code']})
|
||||||
|
assert agent.id in ['agent1', 'agent3']
|
||||||
|
|
||||||
|
def test_no_matching_capabilities(self):
|
||||||
|
"""测试无匹配能力"""
|
||||||
|
scheduler = CapabilityMatchScheduler()
|
||||||
|
agents = [
|
||||||
|
MockAgent(id='agent1', capabilities=['chat']),
|
||||||
|
]
|
||||||
|
|
||||||
|
# 需要 code 能力但无 Agent 具备
|
||||||
|
agent = scheduler.select_agent(agents, {'capabilities': ['code']})
|
||||||
|
# 回退到加权轮询
|
||||||
|
assert agent is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentScheduler:
|
||||||
|
"""Agent 调度器工厂测试"""
|
||||||
|
|
||||||
|
def test_available_strategies(self):
|
||||||
|
"""测试可用策略"""
|
||||||
|
strategies = AgentScheduler.get_available_strategies()
|
||||||
|
|
||||||
|
assert 'round_robin' in strategies
|
||||||
|
assert 'weighted_round_robin' in strategies
|
||||||
|
assert 'least_connections' in strategies
|
||||||
|
assert 'least_response_time' in strategies
|
||||||
|
assert 'capability_match' in strategies
|
||||||
|
|
||||||
|
def test_default_strategy(self):
|
||||||
|
"""测试默认策略"""
|
||||||
|
scheduler = AgentScheduler()
|
||||||
|
assert scheduler.get_strategy() == 'weighted_round_robin'
|
||||||
|
|
||||||
|
def test_custom_strategy(self):
|
||||||
|
"""测试自定义策略"""
|
||||||
|
scheduler = AgentScheduler('round_robin')
|
||||||
|
assert scheduler.get_strategy() == 'round_robin'
|
||||||
Reference in New Issue
Block a user