From 6bafd21e02504a107d904af8b1086b8d36fc2d5d Mon Sep 17 00:00:00 2001 From: "feifei.xu" <307327147@qq.com> Date: Sat, 14 Mar 2026 20:14:59 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20Phase=203=20-=20=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E5=B1=82=20+=20=E6=B5=8B=E8=AF=95=20+=20=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E5=BA=93=E8=BF=81=E7=A7=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGELOG.md | 40 ++++++++ app/utils/__init__.py | 78 +++++++++++++++ app/utils/helpers.py | 113 +++++++++++++++++++++ app/utils/security.py | 121 +++++++++++++++++++++++ app/utils/validators.py | 120 ++++++++++++++++++++++ migrations/alembic.ini | 48 +++++++++ migrations/env.py | 68 +++++++++++++ migrations/versions/initial.py | 134 +++++++++++++++++++++++++ requirements.txt | 1 + tests/__init__.py | 3 + tests/conftest.py | 35 +++++++ tests/test_auth.py | 158 +++++++++++++++++++++++++++++ tests/test_message_queue.py | 96 ++++++++++++++++++ tests/test_scheduler.py | 176 +++++++++++++++++++++++++++++++++ 14 files changed, 1191 insertions(+) create mode 100644 app/utils/helpers.py create mode 100644 app/utils/security.py create mode 100644 app/utils/validators.py create mode 100644 migrations/alembic.ini create mode 100644 migrations/env.py create mode 100644 migrations/versions/initial.py create mode 100644 tests/__init__.py create mode 100644 tests/conftest.py create mode 100644 tests/test_auth.py create mode 100644 tests/test_message_queue.py create mode 100644 tests/test_scheduler.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 00d5438..94aece3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,46 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 --- +## [0.5.0] - 2026-03-14 + +### Added + +#### 🔧 工具层实现 (Phase 3) + +- **validators.py** - 输入验证工具 + - Marshmallow Schema 验证 + - UUID/Email/用户名/URL 验证 + - 字符串清理 + +- **security.py** - 安全工具 + - Token/密码哈希 + - API Key 生成 + - RateLimiter 限流器 + - IPWhitelist IP 白名单 + +- **helpers.py** - 辅助函数 + - 日期时间格式化 + - 分页辅助 + - JSON 安全解析 + +#### 🧪 测试模块 + +- **conftest.py** - pytest 配置 +- **test_auth.py** - 认证 API 单元测试 +- **test_scheduler.py** - 调度器单元测试 +- **test_message_queue.py** - 消息队列单元测试 + +#### 🗄️ 数据库迁移 + +- **Alembic 配置** - 数据库迁移工具 +- **初始迁移脚本** - 创建所有数据表 + +### Changed + +- **requirements.txt** - 添加 gevent 依赖(用于 gunicorn) + +--- + ## [0.4.0] - 2026-03-14 ### Added diff --git a/app/utils/__init__.py b/app/utils/__init__.py index 548e274..26dc4f2 100644 --- a/app/utils/__init__.py +++ b/app/utils/__init__.py @@ -1,3 +1,81 @@ """ 工具模块 """ +from .validators import ( + UserRegistrationSchema, + UserLoginSchema, + SessionCreateSchema, + MessageSendSchema, + AgentRegistrationSchema, + GatewayRegistrationSchema, + validate_uuid, + validate_email, + validate_username, + validate_url, + sanitize_string, + ValidationUtils, +) +from .security import ( + generate_token, + hash_password, + verify_password, + hash_token, + verify_token_hash, + generate_api_key, + secure_compare, + mask_sensitive_data, + RateLimiter, + IPWhitelist, +) +from .helpers import ( + format_datetime, + parse_datetime, + format_duration, + truncate_string, + safe_json_loads, + safe_json_dumps, + generate_session_title, + calculate_timeout, + merge_dicts, + filter_none_values, + PaginationHelper, +) + +__all__ = [ + # Validators + 'UserRegistrationSchema', + 'UserLoginSchema', + 'SessionCreateSchema', + 'MessageSendSchema', + 'AgentRegistrationSchema', + 'GatewayRegistrationSchema', + 'validate_uuid', + 'validate_email', + 'validate_username', + 'validate_url', + 'sanitize_string', + 'ValidationUtils', + # Security + 'generate_token', + 'hash_password', + 'verify_password', + 'hash_token', + 'verify_token_hash', + 'generate_api_key', + 'secure_compare', + 'mask_sensitive_data', + 'RateLimiter', + 'IPWhitelist', + # Helpers + 'format_datetime', + 'parse_datetime', + 'format_duration', + 'truncate_string', + 'safe_json_loads', + 'safe_json_dumps', + 'generate_session_title', + 'calculate_timeout', + 'merge_dicts', + 'filter_none_values', + 'PaginationHelper', +] diff --git a/app/utils/helpers.py b/app/utils/helpers.py new file mode 100644 index 0000000..ae62b73 --- /dev/null +++ b/app/utils/helpers.py @@ -0,0 +1,113 @@ +""" +辅助函数 +""" +from datetime import datetime, timedelta +from typing import Optional, Any +import json + + +def format_datetime(dt: Optional[datetime]) -> Optional[str]: + """格式化日期时间""" + if not dt: + return None + return dt.isoformat() + + +def parse_datetime(dt_str: str) -> Optional[datetime]: + """解析日期时间字符串""" + if not dt_str: + return None + try: + return datetime.fromisoformat(dt_str) + except: + return None + + +def format_duration(seconds: int) -> str: + """格式化时长""" + if seconds < 60: + return f"{seconds}s" + elif seconds < 3600: + return f"{seconds // 60}m" + elif seconds < 86400: + return f"{seconds // 3600}h" + else: + return f"{seconds // 86400}d" + + +def truncate_string(s: str, max_length: int = 100, suffix: str = "...") -> str: + """截断字符串""" + if len(s) <= max_length: + return s + return s[:max_length - len(suffix)] + suffix + + +def safe_json_loads(data: str, default: Any = None) -> Any: + """安全的 JSON 解析""" + try: + return json.loads(data) + except: + return default + + +def safe_json_dumps(data: Any, default: str = "{}") -> str: + """安全的 JSON 序列化""" + try: + return json.dumps(data, ensure_ascii=False) + except: + return default + + +def generate_session_title() -> str: + """生成默认会话标题""" + return f"Session {datetime.utcnow().strftime('%Y-%m-%d %H:%M')}" + + +def calculate_timeout(start_time: datetime, timeout_seconds: int) -> bool: + """检查是否超时""" + if not start_time: + return True + + elapsed = (datetime.utcnow() - start_time).total_seconds() + return elapsed > timeout_seconds + + +def merge_dicts(base: dict, override: dict) -> dict: + """合并字典""" + result = base.copy() + result.update(override) + return result + + +def filter_none_values(data: dict) -> dict: + """过滤字典中的 None 值""" + return {k: v for k, v in data.items() if v is not None} + + +class PaginationHelper: + """分页辅助类""" + + @staticmethod + def paginate(query, page: int = 1, per_page: int = 20): + """分页查询""" + if page < 1: + page = 1 + if per_page < 1: + per_page = 20 + if per_page > 100: + per_page = 100 + + total = query.count() + items = query.offset((page - 1) * per_page).limit(per_page).all() + + pages = (total + per_page - 1) // per_page + + return { + 'items': items, + 'total': total, + 'page': page, + 'per_page': per_page, + 'pages': pages, + 'has_next': page < pages, + 'has_prev': page > 1, + } diff --git a/app/utils/security.py b/app/utils/security.py new file mode 100644 index 0000000..23df58b --- /dev/null +++ b/app/utils/security.py @@ -0,0 +1,121 @@ +""" +安全工具 +""" +import hashlib +import hmac +import secrets +import bcrypt +from typing import Optional + + +def generate_token(length: int = 32) -> str: + """生成随机 Token""" + return secrets.token_urlsafe(length) + + +def hash_password(password: str) -> str: + """密码哈希""" + return bcrypt.hashpw( + password.encode('utf-8'), + bcrypt.gensalt() + ).decode('utf-8') + + +def verify_password(password: str, password_hash: str) -> bool: + """验证密码""" + return bcrypt.checkpw( + password.encode('utf-8'), + password_hash.encode('utf-8') + ) + + +def hash_token(token: str) -> str: + """Token 哈希(用于存储)""" + return hashlib.sha256(token.encode()).hexdigest() + + +def verify_token_hash(token: str, token_hash: str) -> bool: + """验证 Token 哈希""" + return hash_token(token) == token_hash + + +def generate_api_key() -> str: + """生成 API Key""" + return f"pit_{secrets.token_urlsafe(32)}" + + +def secure_compare(val1: str, val2: str) -> bool: + """安全字符串比较(防时序攻击)""" + return hmac.compare_digest(val1.encode(), val2.encode()) + + +def mask_sensitive_data(data: str, visible_chars: int = 4) -> str: + """脱敏处理""" + if len(data) <= visible_chars * 2: + return '*' * len(data) + + return data[:visible_chars] + '*' * (len(data) - visible_chars * 2) + data[-visible_chars:] + + +class RateLimiter: + """简单内存限流器""" + + def __init__(self): + self._storage = {} + + def is_allowed(self, key: str, limit: int = 100, window: int = 60) -> bool: + """检查是否允许请求""" + from time import time + + now = time() + window_start = now - window + + # 获取该 key 的请求记录 + requests = self._storage.get(key, []) + + # 清理过期记录 + requests = [t for t in requests if t > window_start] + + # 检查是否超过限制 + if len(requests) >= limit: + self._storage[key] = requests + return False + + # 记录新请求 + requests.append(now) + self._storage[key] = requests + + return True + + +class IPWhitelist: + """IP 白名单""" + + def __init__(self, allowed_ips: list = None): + self.allowed_ips = set(allowed_ips or []) + self.allow_all = '*' in self.allowed_ips + + def is_allowed(self, ip: str) -> bool: + """检查 IP 是否允许""" + if self.allow_all: + return True + + # 支持 CIDR 格式 + for allowed in self.allowed_ips: + if '/' in allowed: + if self._ip_in_cidr(ip, allowed): + return True + elif ip == allowed: + return True + + return False + + def _ip_in_cidr(self, ip: str, cidr: str) -> bool: + """检查 IP 是否在 CIDR 范围内""" + try: + import ipaddress + network = ipaddress.ip_network(cidr, strict=False) + address = ipaddress.ip_address(ip) + return address in network + except: + return False diff --git a/app/utils/validators.py b/app/utils/validators.py new file mode 100644 index 0000000..0d06437 --- /dev/null +++ b/app/utils/validators.py @@ -0,0 +1,120 @@ +""" +输入验证工具 +""" +from typing import Optional +import re +from marshmallow import Schema, fields, validate, ValidationError + + +class UserRegistrationSchema(Schema): + """用户注册验证""" + username = fields.Str(required=True, validate=validate.Length(min=3, max=80)) + email = fields.Email(required=True) + password = fields.Str(required=True, validate=validate.Length(min=6, max=128)) + nickname = fields.Str(validate=validate.Length(max=80), load_default=None) + + +class UserLoginSchema(Schema): + """用户登录验证""" + username = fields.Str(required=True) + password = fields.Str(required=True) + + +class SessionCreateSchema(Schema): + """会话创建验证""" + title = fields.Str(validate=validate.Length(max=200), load_default=None) + agent_id = fields.Str(validate=validate.Length(max=36), load_default=None) + priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5) + + +class MessageSendSchema(Schema): + """消息发送验证""" + session_id = fields.Str(required=True, validate=validate.Length(max=36)) + content = fields.Str(required=True, validate=validate.Length(min=1, max=10000)) + message_type = fields.Str(validate=validate.OneOf(['text', 'media', 'system']), load_default='text') + reply_to = fields.Str(validate=validate.Length(max=36), load_default=None) + + +class AgentRegistrationSchema(Schema): + """Agent 注册验证""" + name = fields.Str(required=True, validate=validate.Length(min=1, max=80)) + gateway_id = fields.Str(validate=validate.Length(max=36), load_default=None) + model = fields.Str(validate=validate.Length(max=80), load_default=None) + capabilities = fields.List(fields.Str(), load_default=[]) + priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5) + weight = fields.Int(validate=validate.Range(min=1, max=100), load_default=10) + + +class GatewayRegistrationSchema(Schema): + """Gateway 注册验证""" + name = fields.Str(required=True, validate=validate.Length(min=1, max=80)) + url = fields.Str(required=True, validate=validate.Length(max=256)) + token = fields.Str(validate=validate.Length(max=256), load_default=None) + + +def validate_uuid(uuid_str: str) -> bool: + """验证 UUID 格式""" + pattern = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' + return bool(re.match(pattern, uuid_str.lower())) + + +def validate_email(email: str) -> bool: + """验证邮箱格式""" + pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' + return bool(re.match(pattern, email)) + + +def validate_username(username: str) -> bool: + """验证用户名格式""" + # 字母、数字、下划线,3-80 字符 + pattern = r'^[a-zA-Z0-9_]{3,80}$' + return bool(re.match(pattern, username)) + + +def validate_url(url: str) -> bool: + """验证 URL 格式""" + pattern = r'^https?://[^\s/$.?#].[^\s]*$' + return bool(re.match(pattern, url, re.IGNORECASE)) + + +def sanitize_string(input_str: str, max_length: int = 1000) -> str: + """清理字符串输入""" + if not input_str: + return "" + + # 去除首尾空格 + cleaned = input_str.strip() + + # 限制长度 + if len(cleaned) > max_length: + cleaned = cleaned[:max_length] + + # 去除危险字符 + cleaned = cleaned.replace('\x00', '') + + return cleaned + + +class ValidationUtils: + """验证工具类""" + + @staticmethod + def validate_json(data: dict, schema: Schema) -> tuple: + """验证 JSON 数据""" + try: + result = schema.load(data) + return True, result, None + except ValidationError as e: + return False, None, e.messages + + @staticmethod + def validate_pagination(page: int, per_page: int, max_per_page: int = 100) -> tuple: + """验证分页参数""" + if page < 1: + page = 1 + if per_page < 1: + per_page = 20 + if per_page > max_per_page: + per_page = max_per_page + + return page, per_page diff --git a/migrations/alembic.ini b/migrations/alembic.ini new file mode 100644 index 0000000..eeca62b --- /dev/null +++ b/migrations/alembic.ini @@ -0,0 +1,48 @@ +[alembic] +# 脚本位置 +script_location = migrations + +# 模板文件 +template_file = + +# 最大保留版本数 +max_num = 10 + +# 数据库 URL(运行时覆盖) +sqlalchemy.url = sqlite:///pit_router.db + +[post_write_hooks] + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/migrations/env.py b/migrations/env.py new file mode 100644 index 0000000..dbdbe60 --- /dev/null +++ b/migrations/env.py @@ -0,0 +1,68 @@ +""" +Alembic 环境配置 +""" +from logging.config import fileConfig +from sqlalchemy import engine_from_config +from sqlalchemy import pool +from alembic import context +import os +import sys + +# 添加项目路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from app.extensions import db +from app.config import Config + +# Alembic Config 对象 +config = context.config + +# 配置日志 +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# 模型元数据 +target_metadata = db.Model.metadata + +def get_url(): + """获取数据库 URL""" + return os.getenv('DATABASE_URL', 'sqlite:///pit_router.db') + + +def run_migrations_offline() -> None: + """离线迁移""" + url = get_url() + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def run_migrations_online() -> None: + """在线迁移""" + configuration = config.get_section(config.config_ini_section) + configuration["sqlalchemy.url"] = get_url() + connectable = engine_from_config( + configuration, + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + with connectable.connect() as connection: + context.configure( + connection=connection, target_metadata=target_metadata + ) + + with context.begin_transaction(): + context.run_migrations() + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/migrations/versions/initial.py b/migrations/versions/initial.py new file mode 100644 index 0000000..7dabf2c --- /dev/null +++ b/migrations/versions/initial.py @@ -0,0 +1,134 @@ +""" +初始化迁移脚本 +Revision ID: initial +Revises: +Create Date: 2026-03-14 +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers +revision = 'initial' +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade(): + """创建表""" + # 用户表 + op.create_table( + 'users', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('username', sa.String(80), unique=True, nullable=False), + sa.Column('password_hash', sa.String(256), nullable=False), + sa.Column('email', sa.String(120), unique=True, nullable=False), + sa.Column('nickname', sa.String(80)), + sa.Column('role', sa.String(20), default='user'), + sa.Column('status', sa.String(20), default='active'), + sa.Column('created_at', sa.DateTime, default=sa.func.now()), + sa.Column('last_login_at', sa.DateTime), + ) + + # Gateway 表 + op.create_table( + 'gateways', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('name', sa.String(80), unique=True, nullable=False), + sa.Column('url', sa.String(256), nullable=False), + sa.Column('token_hash', sa.String(256)), + sa.Column('status', sa.String(20), default='offline'), + sa.Column('agent_count', sa.Integer, default=0), + sa.Column('connection_limit', sa.Integer, default=10), + sa.Column('heartbeat_interval', sa.Integer, default=60), + sa.Column('allowed_ips', sa.JSON), + sa.Column('last_heartbeat', sa.DateTime), + sa.Column('created_at', sa.DateTime, default=sa.func.now()), + ) + + # Agent 表 + op.create_table( + 'agents', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('name', sa.String(80), nullable=False), + sa.Column('display_name', sa.String(80)), + sa.Column('gateway_id', sa.String(36), sa.ForeignKey('gateways.id')), + sa.Column('socket_id', sa.String(100)), + sa.Column('model', sa.String(80)), + sa.Column('capabilities', sa.JSON), + sa.Column('status', sa.String(20), default='offline'), + sa.Column('priority', sa.Integer, default=5), + sa.Column('weight', sa.Integer, default=10), + sa.Column('connection_limit', sa.Integer, default=5), + sa.Column('current_sessions', sa.Integer, default=0), + sa.Column('last_heartbeat', sa.DateTime), + sa.Column('created_at', sa.DateTime, default=sa.func.now()), + ) + + # 会话表 + op.create_table( + 'sessions', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('user_id', sa.String(36), sa.ForeignKey('users.id'), nullable=False), + sa.Column('primary_agent_id', sa.String(36), sa.ForeignKey('agents.id')), + sa.Column('participating_agent_ids', sa.JSON), + sa.Column('user_socket_id', sa.String(100)), + sa.Column('title', sa.String(200)), + sa.Column('channel_type', sa.String(20), default='web'), + sa.Column('status', sa.String(20), default='active'), + sa.Column('message_count', sa.Integer, default=0), + sa.Column('unread_count', sa.Integer, default=0), + sa.Column('created_at', sa.DateTime, default=sa.func.now()), + sa.Column('updated_at', sa.DateTime), + sa.Column('last_active_at', sa.DateTime), + ) + + # 消息表 + op.create_table( + 'messages', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('session_id', sa.String(36), sa.ForeignKey('sessions.id'), nullable=False), + sa.Column('sender_type', sa.String(20), nullable=False), + sa.Column('sender_id', sa.String(36), nullable=False), + sa.Column('message_type', sa.String(20), default='text'), + sa.Column('content', sa.Text), + sa.Column('content_type', sa.String(20), default='markdown'), + sa.Column('reply_to', sa.String(36)), + sa.Column('status', sa.String(20), default='sent'), + sa.Column('ack_status', sa.String(20), default='pending'), + sa.Column('retry_count', sa.Integer, default=0), + sa.Column('created_at', sa.DateTime, default=sa.func.now()), + sa.Column('delivered_at', sa.DateTime), + ) + + # 连接表 + op.create_table( + 'connections', + sa.Column('id', sa.String(36), primary_key=True), + sa.Column('socket_id', sa.String(100), unique=True, nullable=False), + sa.Column('connection_type', sa.String(20), nullable=False), + sa.Column('entity_id', sa.String(36), nullable=False), + sa.Column('entity_type', sa.String(20), nullable=False), + sa.Column('ip_address', sa.String(45)), + sa.Column('user_agent', sa.String(500)), + sa.Column('status', sa.String(20), default='connected'), + sa.Column('auth_token', sa.String(500)), + sa.Column('connected_at', sa.DateTime, default=sa.func.now()), + sa.Column('last_activity', sa.DateTime), + sa.Column('disconnected_at', sa.DateTime), + ) + + # 创建索引 + op.create_index('ix_messages_session_id', 'messages', ['session_id']) + op.create_index('ix_sessions_user_id', 'sessions', ['user_id']) + op.create_index('ix_sessions_status', 'sessions', ['status']) + + +def downgrade(): + """删除表""" + op.drop_table('connections') + op.drop_table('messages') + op.drop_table('sessions') + op.drop_table('agents') + op.drop_table('gateways') + op.drop_table('users') diff --git a/requirements.txt b/requirements.txt index 458967e..795ece1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ Flask-Login==0.6.3 Flask-Migrate==4.0.5 Flask-SocketIO==5.3.6 Flask-SQLAlchemy==3.1.1 +gevent==24.2.1 gunicorn==21.2.0 psycopg2-binary==2.9.9 PyJWT==2.8.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..759177f --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,3 @@ +""" +测试模块初始化 +""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..fc9fca4 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,35 @@ +""" +pytest 配置 +""" +import pytest +import sys +import os + +# 添加项目根目录到路径 +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +@pytest.fixture +def app(): + """创建测试应用""" + from app import create_app + app = create_app('testing') + + with app.app_context(): + from app.extensions import db + db.create_all() + yield app + db.session.remove() + db.drop_all() + + +@pytest.fixture +def client(app): + """创建测试客户端""" + return app.test_client() + + +@pytest.fixture +def runner(app): + """创建测试运行器""" + return app.test_cli_runner() diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..cc4ba2a --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,158 @@ +""" +认证 API 单元测试 +""" +import pytest +from app import create_app +from app.extensions import db +from app.models import User +from app.utils.security import hash_password + + +@pytest.fixture +def app(): + """创建测试应用""" + app = create_app('testing') + with app.app_context(): + db.create_all() + yield app + db.drop_all() + + +@pytest.fixture +def client(app): + """创建测试客户端""" + return app.test_client() + + +class TestAuthAPI: + """认证 API 测试""" + + def test_register_success(self, client): + """测试用户注册成功""" + response = client.post('/api/auth/register', json={ + 'username': 'testuser', + 'email': 'test@example.com', + 'password': 'password123', + 'nickname': 'Test User' + }) + + assert response.status_code == 201 + data = response.get_json() + assert data['user']['username'] == 'testuser' + assert data['user']['email'] == 'test@example.com' + + def test_register_duplicate_username(self, client, app): + """测试重复用户名""" + # 创建已存在的用户 + with app.app_context(): + user = User( + username='testuser', + email='existing@example.com', + password_hash=hash_password('password') + ) + db.session.add(user) + db.session.commit() + + # 尝试注册相同用户名 + response = client.post('/api/auth/register', json={ + 'username': 'testuser', + 'email': 'new@example.com', + 'password': 'password123' + }) + + assert response.status_code == 400 + + def test_register_missing_fields(self, client): + """测试缺少必填字段""" + response = client.post('/api/auth/register', json={ + 'username': 'testuser' + }) + + assert response.status_code == 400 + + def test_login_success(self, client, app): + """测试登录成功""" + # 创建用户 + with app.app_context(): + user = User( + username='testuser', + email='test@example.com', + password_hash=hash_password('password123') + ) + db.session.add(user) + db.session.commit() + + # 登录 + response = client.post('/api/auth/login', json={ + 'username': 'testuser', + 'password': 'password123' + }) + + assert response.status_code == 200 + data = response.get_json() + assert 'access_token' in data + assert 'refresh_token' in data + + def test_login_invalid_credentials(self, client): + """测试无效凭据""" + response = client.post('/api/auth/login', json={ + 'username': 'nonexistent', + 'password': 'wrong' + }) + + assert response.status_code == 401 + + def test_verify_token(self, client, app): + """测试 Token 验证""" + # 创建用户 + with app.app_context(): + user = User( + username='testuser', + email='test@example.com', + password_hash=hash_password('password123') + ) + db.session.add(user) + db.session.commit() + + # 登录获取 token + login_response = client.post('/api/auth/login', json={ + 'username': 'testuser', + 'password': 'password123' + }) + token = login_response.get_json()['access_token'] + + # 验证 token + response = client.post('/api/auth/verify', headers={ + 'Authorization': f'Bearer {token}' + }) + + assert response.status_code == 200 + data = response.get_json() + assert data['valid'] is True + + +class TestValidation: + """验证测试""" + + def test_validate_email(self, client): + """测试邮箱验证""" + from app.utils.validators import validate_email + + assert validate_email('test@example.com') is True + assert validate_email('invalid-email') is False + + def test_validate_username(self, client): + """测试用户名验证""" + from app.utils.validators import validate_username + + assert validate_username('testuser') is True + assert validate_username('ab') is False # 太短 + assert validate_username('test-user') is True + assert validate_username('test@user') is False # 包含非法字符 + + def test_validate_uuid(self, client): + """测试 UUID 验证""" + from app.utils.validators import validate_uuid + + assert validate_uuid('550e8400-e29b-41d4-a716-446655440000') is True + assert validate_uuid('invalid-uuid') is False diff --git a/tests/test_message_queue.py b/tests/test_message_queue.py new file mode 100644 index 0000000..67c7390 --- /dev/null +++ b/tests/test_message_queue.py @@ -0,0 +1,96 @@ +""" +消息队列单元测试 +""" +import pytest +from app.services.message_queue import MessageQueue +from unittest.mock import MagicMock, patch + + +class TestMessageQueue: + """消息队列测试""" + + @pytest.fixture + def mock_redis(self): + """模拟 Redis 客户端""" + with patch('app.services.message_queue.redis_client') as mock: + mock.hset.return_value = 1 + mock.rpush.return_value = 1 + mock.lpop.return_value = None + mock.hgetall.return_value = {} + mock.delete.return_value = 1 + mock.llen.return_value = 0 + yield mock + + def test_enqueue(self, mock_redis): + """测试消息入队""" + message = MagicMock() + message.id = 'msg-123' + message.session_id = 'session-123' + message.sender_type = 'user' + message.sender_id = 'user-123' + message.content = 'Hello' + message.status = 'sent' + message.retry_count = 0 + message.created_at = None + + result = MessageQueue.enqueue(message) + + assert result is True + mock_redis.hset.assert_called() + mock_redis.rpush.assert_called() + + def test_dequeue_empty(self, mock_redis): + """测试空队列出队""" + mock_redis.lpop.return_value = None + + result = MessageQueue.dequeue() + + assert result is None + + def test_dequeue_success(self, mock_redis): + """测试成功出队""" + mock_redis.lpop.return_value = 'msg-123' + mock_redis.hgetall.return_value = { + 'id': 'msg-123', + 'session_id': 'session-123', + 'content': 'Hello', + 'retry_count': '0', + } + + result = MessageQueue.dequeue() + + assert result is not None + assert result['id'] == 'msg-123' + + def test_ack(self, mock_redis): + """测试消息确认""" + result = MessageQueue.ack('msg-123') + + assert result is True + mock_redis.delete.assert_called() + + def test_retry_within_limit(self, mock_redis): + """测试重试(未超限)""" + mock_redis.hget.return_value = '1' + + result = MessageQueue.retry('msg-123') + + assert result is True + mock_redis.hset.assert_called() + mock_redis.rpush.assert_called() + + def test_retry_exceed_limit(self, mock_redis): + """测试重试(超限)""" + mock_redis.hget.return_value = '3' + + result = MessageQueue.retry('msg-123') + + assert result is False + + def test_get_pending_count(self, mock_redis): + """测试获取待处理数量""" + mock_redis.llen.return_value = 5 + + count = MessageQueue.get_pending_count() + + assert count == 5 diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py new file mode 100644 index 0000000..4d19b54 --- /dev/null +++ b/tests/test_scheduler.py @@ -0,0 +1,176 @@ +""" +调度器单元测试 +""" +import pytest +from app.services.scheduler import ( + AgentScheduler, + RoundRobinScheduler, + WeightedRoundRobinScheduler, + LeastConnectionsScheduler, + LeastResponseTimeScheduler, + CapabilityMatchScheduler, +) +from datetime import datetime, timedelta + + +class MockAgent: + """模拟 Agent""" + def __init__(self, id, status='online', weight=10, priority=5, + current_sessions=0, connection_limit=5, capabilities=None, + last_heartbeat=None): + self.id = id + self.status = status + self.weight = weight + self.priority = priority + self.current_sessions = current_sessions + self.connection_limit = connection_limit + self.capabilities = capabilities or [] + self.last_heartbeat = last_heartbeat + + +class TestRoundRobinScheduler: + """轮询调度器测试""" + + def test_select_agent(self): + """测试选择 Agent""" + scheduler = RoundRobinScheduler() + agents = [ + MockAgent(id='agent1'), + MockAgent(id='agent2'), + MockAgent(id='agent3'), + ] + + # 轮询选择 + agent1 = scheduler.select_agent(agents) + agent2 = scheduler.select_agent(agents) + agent3 = scheduler.select_agent(agents) + agent4 = scheduler.select_agent(agents) # 循环回第一个 + + assert agent1.id == 'agent1' + assert agent2.id == 'agent2' + assert agent3.id == 'agent3' + assert agent4.id == 'agent1' + + def test_no_online_agents(self): + """测试无在线 Agent""" + scheduler = RoundRobinScheduler() + agents = [ + MockAgent(id='agent1', status='offline'), + MockAgent(id='agent2', status='offline'), + ] + + agent = scheduler.select_agent(agents) + assert agent is None + + +class TestWeightedRoundRobinScheduler: + """加权轮询调度器测试""" + + def test_weight_distribution(self): + """测试权重分布""" + scheduler = WeightedRoundRobinScheduler() + agents = [ + MockAgent(id='agent1', weight=3), + MockAgent(id='agent2', weight=1), + ] + + # 统计选择次数 + counts = {'agent1': 0, 'agent2': 0} + for _ in range(100): + agent = scheduler.select_agent(agents) + counts[agent.id] += 1 + + # agent1 权重 3,agent2 权重 1,比例应该约 3:1 + assert counts['agent1'] > counts['agent2'] + + def test_zero_weight(self): + """测试零权重""" + scheduler = WeightedRoundRobinScheduler() + agents = [ + MockAgent(id='agent1', weight=0), + MockAgent(id='agent2', weight=0), + ] + + # 权重都为 0 时返回第一个 + agent = scheduler.select_agent(agents) + assert agent is not None + + +class TestLeastConnectionsScheduler: + """最少连接调度器测试""" + + def test_select_least_connections(self): + """测试选择最少连接""" + scheduler = LeastConnectionsScheduler() + agents = [ + MockAgent(id='agent1', current_sessions=5), + MockAgent(id='agent2', current_sessions=2), + MockAgent(id='agent3', current_sessions=8), + ] + + agent = scheduler.select_agent(agents) + assert agent.id == 'agent2' + + def test_all_at_limit(self): + """测试所有 Agent 都达上限""" + scheduler = LeastConnectionsScheduler() + agents = [ + MockAgent(id='agent1', current_sessions=5, connection_limit=5), + MockAgent(id='agent2', current_sessions=5, connection_limit=5), + ] + + agent = scheduler.select_agent(agents) + assert agent is None + + +class TestCapabilityMatchScheduler: + """能力匹配调度器测试""" + + def test_match_capabilities(self): + """测试能力匹配""" + scheduler = CapabilityMatchScheduler() + agents = [ + MockAgent(id='agent1', capabilities=['chat', 'code']), + MockAgent(id='agent2', capabilities=['chat']), + MockAgent(id='agent3', capabilities=['code', 'translate']), + ] + + # 需要 code 能力 + agent = scheduler.select_agent(agents, {'capabilities': ['code']}) + assert agent.id in ['agent1', 'agent3'] + + def test_no_matching_capabilities(self): + """测试无匹配能力""" + scheduler = CapabilityMatchScheduler() + agents = [ + MockAgent(id='agent1', capabilities=['chat']), + ] + + # 需要 code 能力但无 Agent 具备 + agent = scheduler.select_agent(agents, {'capabilities': ['code']}) + # 回退到加权轮询 + assert agent is not None + + +class TestAgentScheduler: + """Agent 调度器工厂测试""" + + def test_available_strategies(self): + """测试可用策略""" + strategies = AgentScheduler.get_available_strategies() + + assert 'round_robin' in strategies + assert 'weighted_round_robin' in strategies + assert 'least_connections' in strategies + assert 'least_response_time' in strategies + assert 'capability_match' in strategies + + def test_default_strategy(self): + """测试默认策略""" + scheduler = AgentScheduler() + assert scheduler.get_strategy() == 'weighted_round_robin' + + def test_custom_strategy(self): + """测试自定义策略""" + scheduler = AgentScheduler('round_robin') + assert scheduler.get_strategy() == 'round_robin'