"""Input validation utilities: marshmallow request schemas and format-checking helpers."""
|
||
from typing import Optional
|
||
import re
|
||
from marshmallow import Schema, fields, validate, ValidationError
|
||
|
||
|
||
class UserRegistrationSchema(Schema):
    """Schema for validating a user registration payload."""

    # Required; length must be between 3 and 80 characters.
    username = fields.Str(
        required=True,
        validate=validate.Length(min=3, max=80),
    )

    # Required; must pass marshmallow's e-mail validation.
    email = fields.Email(required=True)

    # Required; length must be between 6 and 128 characters.
    password = fields.Str(
        required=True,
        validate=validate.Length(min=6, max=128),
    )

    # Optional; at most 80 characters, loads as None when omitted.
    nickname = fields.Str(
        validate=validate.Length(max=80),
        load_default=None,
    )
|
||
|
||
|
||
class UserLoginSchema(Schema):
    """Schema for validating a user login payload.

    Both fields are mandatory strings; no length or format constraints are
    applied here.
    """

    username = fields.Str(
        required=True,
    )
    password = fields.Str(
        required=True,
    )
|
||
|
||
|
||
class SessionCreateSchema(Schema):
    """Schema for validating a session creation payload."""

    # Optional; at most 200 characters, loads as None when omitted.
    title = fields.Str(
        validate=validate.Length(max=200),
        load_default=None,
    )

    # Optional; at most 36 characters, loads as None when omitted.
    # NOTE(review): 36 is the length of a hyphenated UUID - presumably this is
    # a UUID string, but only the length is enforced here; confirm.
    agent_id = fields.Str(
        validate=validate.Length(max=36),
        load_default=None,
    )

    # Optional integer in the inclusive range 1-10; loads as 5 when omitted.
    priority = fields.Int(
        validate=validate.Range(min=1, max=10),
        load_default=5,
    )
|
||
|
||
|
||
class MessageSendSchema(Schema):
    """Schema for validating a message send payload."""

    # Required; at most 36 characters.
    session_id = fields.Str(
        required=True,
        validate=validate.Length(max=36),
    )

    # Required; between 1 and 10000 characters.
    content = fields.Str(
        required=True,
        validate=validate.Length(min=1, max=10000),
    )

    # Optional; one of the listed kinds, loads as 'text' when omitted.
    message_type = fields.Str(
        validate=validate.OneOf(['text', 'media', 'system']),
        load_default='text',
    )

    # Optional; at most 36 characters, loads as None when omitted.
    reply_to = fields.Str(
        validate=validate.Length(max=36),
        load_default=None,
    )
|
||
|
||
|
||
class AgentRegistrationSchema(Schema):
    """Schema for validating an agent registration payload."""

    # Required; between 1 and 80 characters.
    name = fields.Str(required=True, validate=validate.Length(min=1, max=80))

    # Optional; at most 36 characters, loads as None when omitted.
    gateway_id = fields.Str(validate=validate.Length(max=36), load_default=None)

    # Optional; at most 80 characters, loads as None when omitted.
    model = fields.Str(validate=validate.Length(max=80), load_default=None)

    # Optional list of strings. The default is the callable ``list`` rather
    # than a literal ``[]``: a literal would be a single list object handed to
    # every load that omits the field, so one caller mutating its result would
    # leak into all later results.
    capabilities = fields.List(fields.Str(), load_default=list)

    # Optional integer in the inclusive range 1-10; loads as 5 when omitted.
    priority = fields.Int(validate=validate.Range(min=1, max=10), load_default=5)

    # Optional integer in the inclusive range 1-100; loads as 10 when omitted.
    weight = fields.Int(validate=validate.Range(min=1, max=100), load_default=10)
|
||
|
||
|
||
class GatewayRegistrationSchema(Schema):
    """Schema for validating a gateway registration payload."""

    # Required; between 1 and 80 characters.
    name = fields.Str(
        required=True,
        validate=validate.Length(min=1, max=80),
    )

    # Required; at most 256 characters. Only the length is checked here, not
    # the URL syntax.
    url = fields.Str(
        required=True,
        validate=validate.Length(max=256),
    )

    # Optional; at most 256 characters, loads as None when omitted.
    token = fields.Str(
        validate=validate.Length(max=256),
        load_default=None,
    )
|
||
|
||
|
||
def validate_uuid(uuid_str: str) -> bool:
    """Return True if *uuid_str* is a hyphenated 8-4-4-4-12 hexadecimal UUID.

    The check is purely syntactic and case-insensitive; version and variant
    bits are not inspected.

    Args:
        uuid_str: Candidate value. Non-string values are rejected rather than
            raising.

    Returns:
        True when the whole string matches the UUID layout, False otherwise.
    """
    if not isinstance(uuid_str, str):
        return False
    pattern = r'[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}'
    # fullmatch, not match + '$': '$' also matches just before a trailing
    # newline, which would let "<uuid>\n" through.
    return re.fullmatch(pattern, uuid_str) is not None
|
||
|
||
|
||
def validate_email(email: str) -> bool:
    """Return True if *email* looks like ``local@domain.tld``.

    This is a pragmatic ASCII-only pattern check, not a full RFC 5322
    validation: the local part may contain letters, digits and ``._%+-``; the
    domain may contain letters, digits, dots and hyphens and must end in a
    dot followed by at least two letters.

    Args:
        email: Candidate value. Non-string values are rejected rather than
            raising.

    Returns:
        True when the whole string matches the pattern, False otherwise.
    """
    if not isinstance(email, str):
        return False
    pattern = r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}'
    # fullmatch, not match + '$': '$' also matches just before a trailing
    # newline, which would let "user@example.com\n" through.
    return re.fullmatch(pattern, email) is not None
|
||
|
||
|
||
def validate_username(username: str) -> bool:
    """Return True if *username* is 3-80 ASCII letters, digits or underscores.

    Args:
        username: Candidate value. Non-string values are rejected rather than
            raising.

    Returns:
        True when the whole string matches, False otherwise.
    """
    if not isinstance(username, str):
        return False
    # Letters, digits and underscore only; 3 to 80 characters.
    pattern = r'[a-zA-Z0-9_]{3,80}'
    # fullmatch, not match + '$': '$' also matches just before a trailing
    # newline, which would let "name\n" through.
    return re.fullmatch(pattern, username) is not None
|
||
|
||
|
||
def validate_url(url: str) -> bool:
    """Return True if *url* is an ``http://`` or ``https://`` URL without whitespace.

    The check is syntactic only: the scheme (case-insensitive) must be
    followed by at least two non-whitespace characters, the first of which is
    not one of ``/ $ . ? #``.

    Args:
        url: Candidate value. Non-string values are rejected rather than
            raising.

    Returns:
        True when the whole string matches, False otherwise.
    """
    if not isinstance(url, str):
        return False
    # The character after the first host character is restricted to
    # non-whitespace (\S) - a bare '.' there would accept a space or tab,
    # e.g. "http://a b".
    pattern = r'https?://[^\s/$.?#]\S+'
    # fullmatch, not match + '$': '$' also matches just before a trailing
    # newline, which would let "http://example.com\n" through.
    return re.fullmatch(pattern, url, re.IGNORECASE) is not None
|
||
|
||
|
||
def sanitize_string(input_str: str, max_length: int = 1000) -> str:
    """Clean a string input: drop NUL characters, trim, and cap the length.

    Steps, in order:
      1. Remove every NUL (``\\x00``) character.
      2. Strip leading and trailing whitespace.
      3. Truncate to at most *max_length* characters.

    NULs are removed first so that whitespace sitting next to a NUL is still
    stripped and NULs do not use up part of the length budget.

    Args:
        input_str: The raw string. Any falsy value (``None``, ``""``) yields
            an empty string.
        max_length: Maximum number of characters kept after trimming.

    Returns:
        The cleaned string. Truncation happens after trimming, so the result
        can end in whitespace that was interior to the original text.
    """
    if not input_str:
        return ""

    # Remove dangerous characters first, then trim the surrounding whitespace.
    cleaned = input_str.replace('\x00', '').strip()

    # Cap the length; slicing is a no-op when the string is already short enough.
    return cleaned[:max_length]
|
||
|
||
|
||
class ValidationUtils:
    """Static helpers for validating request payloads and pagination arguments."""

    @staticmethod
    def validate_json(data: dict, schema: 'Schema') -> tuple:
        """Validate *data* by loading it through *schema*.

        Args:
            data: The payload to validate.
            schema: A schema instance; its ``load`` method is called with
                *data*.

        Returns:
            A 3-tuple ``(ok, result, errors)``: ``(True, loaded, None)`` when
            loading succeeds, or ``(False, None, messages)`` when ``load``
            raises ``ValidationError``, where *messages* is the exception's
            ``messages`` attribute.
        """
        # Keep the try body to the one call that can raise ValidationError.
        try:
            result = schema.load(data)
        except ValidationError as e:
            return False, None, e.messages
        return True, result, None

    @staticmethod
    def validate_pagination(page: int, per_page: int, max_per_page: int = 100) -> tuple:
        """Normalize pagination arguments into safe bounds.

        Values that cannot be converted to ``int`` (``None``, non-numeric
        strings, NaN/infinity) are treated like out-of-range values and
        replaced with the defaults instead of raising.

        Args:
            page: Requested page number; anything below 1 or not convertible
                becomes 1.
            per_page: Requested page size; anything below 1 or not
                convertible becomes 20.
            max_per_page: Upper bound applied to the page size.

        Returns:
            A 2-tuple ``(page, per_page)`` of integers.
        """
        try:
            page = int(page)
        except (TypeError, ValueError, OverflowError):
            page = 1
        try:
            per_page = int(per_page)
        except (TypeError, ValueError, OverflowError):
            per_page = 20

        if page < 1:
            page = 1
        if per_page < 1:
            per_page = 20
        # The cap is applied last so the default of 20 is also bounded by it.
        if per_page > max_per_page:
            per_page = max_per_page

        return page, per_page
|