""" 输入验证工具 """ from typing import Optional import re from marshmallow import Schema, fields, validate, ValidationError class UserRegistrationSchema(Schema): """用户注册验证""" username = fields.Str(required=True, validate=validate.Length(min=3, max=80)) email = fields.Email(required=True) password = fields.Str(required=True, validate=validate.Length(min=6, max=128)) nickname = fields.Str(validate=validate.Length(max=80), load_default=None) class UserLoginSchema(Schema): """用户登录验证""" username = fields.Str(required=True) password = fields.Str(required=True) class SessionCreateSchema(Schema): """会话创建验证""" title = fields.Str(validate=validate.Length(max=200), load_default=None) agent_id = fields.Str(validate=validate.Length(max=36), load_default=None) priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5) class MessageSendSchema(Schema): """消息发送验证""" session_id = fields.Str(required=True, validate=validate.Length(max=36)) content = fields.Str(required=True, validate=validate.Length(min=1, max=10000)) message_type = fields.Str(validate=validate.OneOf(['text', 'media', 'system']), load_default='text') reply_to = fields.Str(validate=validate.Length(max=36), load_default=None) class AgentRegistrationSchema(Schema): """Agent 注册验证""" name = fields.Str(required=True, validate=validate.Length(min=1, max=80)) gateway_id = fields.Str(validate=validate.Length(max=36), load_default=None) model = fields.Str(validate=validate.Length(max=80), load_default=None) capabilities = fields.List(fields.Str(), load_default=[]) priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5) weight = fields.Int(validate=validate.Range(min=1, max=100), load_default=10) class GatewayRegistrationSchema(Schema): """Gateway 注册验证""" name = fields.Str(required=True, validate=validate.Length(min=1, max=80)) url = fields.Str(required=True, validate=validate.Length(max=256)) token = fields.Str(validate=validate.Length(max=256), load_default=None) def validate_uuid(uuid_str: str) -> bool: """验证 UUID 格式""" pattern = r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$' return bool(re.match(pattern, uuid_str.lower())) def validate_email(email: str) -> bool: """验证邮箱格式""" pattern = r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$' return bool(re.match(pattern, email)) def validate_username(username: str) -> bool: """验证用户名格式""" # 字母、数字、下划线,3-80 字符 pattern = r'^[a-zA-Z0-9_]{3,80}$' return bool(re.match(pattern, username)) def validate_url(url: str) -> bool: """验证 URL 格式""" pattern = r'^https?://[^\s/$.?#].[^\s]*$' return bool(re.match(pattern, url, re.IGNORECASE)) def sanitize_string(input_str: str, max_length: int = 1000) -> str: """清理字符串输入""" if not input_str: return "" # 去除首尾空格 cleaned = input_str.strip() # 限制长度 if len(cleaned) > max_length: cleaned = cleaned[:max_length] # 去除危险字符 cleaned = cleaned.replace('\x00', '') return cleaned class ValidationUtils: """验证工具类""" @staticmethod def validate_json(data: dict, schema: Schema) -> tuple: """验证 JSON 数据""" try: result = schema.load(data) return True, result, None except ValidationError as e: return False, None, e.messages @staticmethod def validate_pagination(page: int, per_page: int, max_per_page: int = 100) -> tuple: """验证分页参数""" if page < 1: page = 1 if per_page < 1: per_page = 20 if per_page > max_per_page: per_page = max_per_page return page, per_page