121 lines
3.9 KiB
Python
121 lines
3.9 KiB
Python
|
|
"""
|
|||
|
|
输入验证工具
|
|||
|
|
"""
|
|||
|
|
from typing import Optional
|
|||
|
|
import re
|
|||
|
|
from marshmallow import Schema, fields, validate, ValidationError
|
|||
|
|
|
|||
|
|
|
|||
|
|
class UserRegistrationSchema(Schema):
    """Validate the payload of a user-registration request.

    Fields:
        username: required string, 3-80 characters.
        email: required, checked by marshmallow's ``Email`` field.
        password: required string, 6-128 characters.
        nickname: optional string of at most 80 characters; ``None`` when omitted.
    """

    username = fields.String(
        required=True,
        validate=validate.Length(3, 80),
    )
    email = fields.Email(required=True)
    password = fields.String(
        required=True,
        validate=validate.Length(6, 128),
    )
    nickname = fields.String(
        validate=validate.Length(max=80),
        load_default=None,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class UserLoginSchema(Schema):
    """Validate the payload of a login request.

    Both ``username`` and ``password`` must be present; no length or
    format rules are applied at this layer.
    """

    username = fields.String(required=True)
    password = fields.String(required=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SessionCreateSchema(Schema):
    """Validate the payload used to create a session.

    Every field is optional:
        title: at most 200 characters; ``None`` when omitted.
        agent_id: at most 36 characters (enough for a hyphenated UUID
            string); ``None`` when omitted.
        priority: integer in the inclusive range 1-10; defaults to 5.
    """

    title = fields.String(
        validate=validate.Length(max=200),
        load_default=None,
    )
    agent_id = fields.String(
        validate=validate.Length(max=36),
        load_default=None,
    )
    priority = fields.Integer(
        validate=validate.Range(1, 10),
        load_default=5,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class MessageSendSchema(Schema):
    """Validate the payload of a send-message request.

    Fields:
        session_id: required, at most 36 characters.
        content: required, 1-10000 characters.
        message_type: one of ``'text'``, ``'media'`` or ``'system'``;
            defaults to ``'text'``.
        reply_to: optional, at most 36 characters; ``None`` when omitted.
    """

    session_id = fields.String(
        required=True,
        validate=validate.Length(max=36),
    )
    content = fields.String(
        required=True,
        validate=validate.Length(1, 10000),
    )
    message_type = fields.String(
        validate=validate.OneOf(['text', 'media', 'system']),
        load_default='text',
    )
    reply_to = fields.String(
        validate=validate.Length(max=36),
        load_default=None,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
class AgentRegistrationSchema(Schema):
    """Validate the payload of an agent-registration request.

    Fields:
        name: required, 1-80 characters.
        gateway_id: optional, at most 36 characters; ``None`` when omitted.
        model: optional, at most 80 characters; ``None`` when omitted.
        capabilities: optional list of strings; a fresh empty list when omitted.
        priority: integer in the inclusive range 1-10; defaults to 5.
        weight: integer in the inclusive range 1-100; defaults to 10.
    """

    name = fields.Str(required=True, validate=validate.Length(min=1, max=80))
    gateway_id = fields.Str(validate=validate.Length(max=36), load_default=None)
    model = fields.Str(validate=validate.Length(max=80), load_default=None)
    # Pass the ``list`` factory instead of a ``[]`` literal: marshmallow returns
    # a non-callable load_default as-is, so a literal list would be one object
    # shared by every load that omits ``capabilities``; mutating it for one
    # request would leak into all later ones.
    capabilities = fields.List(fields.Str(), load_default=list)
    priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)
    weight = fields.Int(validate=validate.Range(min=1, max=100), load_default=10)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class GatewayRegistrationSchema(Schema):
    """Validate the payload of a gateway-registration request.

    Fields:
        name: required, 1-80 characters.
        url: required string of at most 256 characters (only the length is
            checked here, not the URL syntax).
        token: optional, at most 256 characters; ``None`` when omitted.
    """

    name = fields.String(
        required=True,
        validate=validate.Length(1, 80),
    )
    url = fields.String(
        required=True,
        validate=validate.Length(max=256),
    )
    token = fields.String(
        validate=validate.Length(max=256),
        load_default=None,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def validate_uuid(uuid_str: str) -> bool:
    """Return True if *uuid_str* is a hyphenated 8-4-4-4-12 hex UUID string.

    The comparison is case-insensitive (the input is lower-cased first).
    Only the layout is checked, not the version or variant digits.

    Args:
        uuid_str: Value to check. Non-string values (including ``None``)
            yield False instead of raising.

    Returns:
        True when the entire string matches the layout, False otherwise.
    """
    if not isinstance(uuid_str, str):
        return False
    pattern = r'[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}'
    # fullmatch rather than match + '$': '$' also matches just before a
    # trailing newline, which let values like "<uuid>\n" pass.
    return re.fullmatch(pattern, uuid_str.lower()) is not None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def validate_email(email: str) -> bool:
    """Return True if *email* has the shape ``local@domain.tld``.

    This is a lightweight syntactic check, not full RFC 5322 validation.
    The rules are:
      - the local part may use ASCII letters, digits and ``._%+-``;
      - the domain may use ASCII letters, digits, ``.`` and ``-``;
      - the last label must be at least two ASCII letters.

    Args:
        email: Value to check. Non-string values (including ``None``)
            yield False instead of raising.

    Returns:
        True when the entire string matches, False otherwise.
    """
    if not isinstance(email, str):
        return False
    pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    # fullmatch rather than match + '$': '$' also matches just before a
    # trailing newline, which let "user@example.com\n" pass.
    return re.fullmatch(pattern, email) is not None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def validate_username(username: str) -> bool:
    """Return True if *username* uses only ASCII letters, digits and ``_``.

    The username must also be 3-80 characters long.

    Args:
        username: Value to check. Non-string values (including ``None``)
            yield False instead of raising.

    Returns:
        True when the entire string matches, False otherwise.
    """
    if not isinstance(username, str):
        return False
    # Letters, digits and underscore, 3-80 characters.
    pattern = r'[a-zA-Z0-9_]{3,80}'
    # fullmatch rather than match + '$': '$' also matches just before a
    # trailing newline, which let "abc\n" pass.
    return re.fullmatch(pattern, username) is not None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def validate_url(url: str) -> bool:
    """Return True if *url* looks like an http(s) URL.

    The scheme (``http`` or ``https``, case-insensitive) must be followed by
    ``://`` and at least two non-whitespace characters, the first of which
    is not ``/``, ``$``, ``.``, ``?`` or ``#``. No whitespace is allowed
    anywhere after the scheme.

    Args:
        url: Value to check. Non-string values (including ``None``) yield
            False instead of raising.

    Returns:
        True when the entire string matches, False otherwise.
    """
    if not isinstance(url, str):
        return False
    pattern = r'https?://[^\s/$.?#].[^\s]*'
    # fullmatch rather than match + '$': '$' also matches just before a
    # trailing newline, which let "http://example.com\n" pass.
    return re.fullmatch(pattern, url, re.IGNORECASE) is not None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def sanitize_string(input_str: str, max_length: int = 1000) -> str:
    """Clean a free-text input string.

    Steps, in order:
      1. remove NUL characters (U+0000);
      2. strip leading/trailing whitespace;
      3. keep at most *max_length* characters.

    NUL removal runs first for two reasons. Whitespace sitting behind a
    leading or trailing NUL is still stripped. The length limit applies to
    the final text, not to text that would shrink afterwards.

    Args:
        input_str: Text to clean. Falsy values (``None``, ``''``) yield ``''``.
        max_length: Maximum number of characters kept. Negative values are
            treated as 0.

    Returns:
        The cleaned string.
    """
    if not input_str:
        return ""

    # Remove dangerous NUL characters before trimming so that whitespace
    # they were shielding from strip() is removed too.
    cleaned = input_str.replace('\x00', '').strip()

    # Slicing is a no-op when the text is already short enough; clamp a
    # negative limit so it cannot turn into "drop the last N characters".
    return cleaned[:max(max_length, 0)]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ValidationUtils:
    """Helpers for validating request payloads and pagination parameters."""

    @staticmethod
    def validate_json(data: dict, schema: 'Schema') -> tuple:
        """Load *data* through *schema* without letting validation errors escape.

        Args:
            data: Raw input to validate (e.g. a decoded JSON body).
            schema: Schema instance whose ``load`` method performs validation.

        Returns:
            ``(True, loaded_data, None)`` on success. When ``schema.load``
            raises ``ValidationError``, returns ``(False, None, messages)``,
            where ``messages`` is the exception's ``messages`` attribute.
        """
        try:
            result = schema.load(data)
        except ValidationError as e:
            return False, None, e.messages
        return True, result, None

    @staticmethod
    def validate_pagination(page: int, per_page: int, max_per_page: int = 100,
                            default_per_page: int = 20) -> tuple:
        """Clamp pagination parameters into a usable range.

        Args:
            page: Requested page number; values below 1 become 1.
            per_page: Requested page size. Values below 1 are replaced by
                *default_per_page*, and the result is capped at *max_per_page*.
            max_per_page: Upper bound on the page size.
            default_per_page: Page size used when *per_page* is below 1.

        Returns:
            ``(page, per_page)`` after adjustment.
        """
        if page < 1:
            page = 1
        if per_page < 1:
            per_page = default_per_page
        # Cap after applying the fallback so a small max_per_page also limits
        # default_per_page.
        per_page = min(per_page, max_per_page)
        return page, per_page
|