""" 安全工具 """ import hashlib import hmac import secrets import bcrypt from typing import Optional def generate_token(length: int = 32) -> str: """生成随机 Token""" return secrets.token_urlsafe(length) def hash_password(password: str) -> str: """密码哈希""" return bcrypt.hashpw( password.encode('utf-8'), bcrypt.gensalt() ).decode('utf-8') def verify_password(password: str, password_hash: str) -> bool: """验证密码""" return bcrypt.checkpw( password.encode('utf-8'), password_hash.encode('utf-8') ) def hash_token(token: str) -> str: """Token 哈希(用于存储)""" return hashlib.sha256(token.encode()).hexdigest() def verify_token_hash(token: str, token_hash: str) -> bool: """验证 Token 哈希""" return hash_token(token) == token_hash def generate_api_key() -> str: """生成 API Key""" return f"pit_{secrets.token_urlsafe(32)}" def secure_compare(val1: str, val2: str) -> bool: """安全字符串比较(防时序攻击)""" return hmac.compare_digest(val1.encode(), val2.encode()) def mask_sensitive_data(data: str, visible_chars: int = 4) -> str: """脱敏处理""" if len(data) <= visible_chars * 2: return '*' * len(data) return data[:visible_chars] + '*' * (len(data) - visible_chars * 2) + data[-visible_chars:] class RateLimiter: """简单内存限流器""" def __init__(self): self._storage = {} def is_allowed(self, key: str, limit: int = 100, window: int = 60) -> bool: """检查是否允许请求""" from time import time now = time() window_start = now - window # 获取该 key 的请求记录 requests = self._storage.get(key, []) # 清理过期记录 requests = [t for t in requests if t > window_start] # 检查是否超过限制 if len(requests) >= limit: self._storage[key] = requests return False # 记录新请求 requests.append(now) self._storage[key] = requests return True class IPWhitelist: """IP 白名单""" def __init__(self, allowed_ips: list = None): self.allowed_ips = set(allowed_ips or []) self.allow_all = '*' in self.allowed_ips def is_allowed(self, ip: str) -> bool: """检查 IP 是否允许""" if self.allow_all: return True # 支持 CIDR 格式 for allowed in self.allowed_ips: if '/' in allowed: if self._ip_in_cidr(ip, allowed): return True elif ip == allowed: return True return False def _ip_in_cidr(self, ip: str, cidr: str) -> bool: """检查 IP 是否在 CIDR 范围内""" try: import ipaddress network = ipaddress.ip_network(cidr, strict=False) address = ipaddress.ip_address(ip) return address in network except: return False