122 lines
3.1 KiB
Python
122 lines
3.1 KiB
Python
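"""Security utilities: random tokens and API keys, password and token hashing,
timing-attack-resistant string comparison, data masking, rate limiting and IP
whitelisting."""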
|
|
import hashlib
|
|
import hmac
|
|
import secrets
|
|
import bcrypt
|
|
from typing import Optional
|
|
|
|
|
|
def generate_token(length: int = 32) -> str:
    """Return a fresh URL-safe random token.

    ``length`` is the number of random bytes drawn via :mod:`secrets`; the
    returned base64url text is longer than that (no padding characters).
    """
    random_text = secrets.token_urlsafe(nbytes=length)
    return random_text
|
|
|
|
|
|
def hash_password(password: str) -> str:
    """Hash ``password`` with bcrypt using a freshly generated salt.

    The password is UTF-8 encoded before hashing, and the resulting bcrypt
    hash bytes are decoded back to UTF-8 text for storage.

    NOTE(review): bcrypt has a 72-byte input limit; confirm how the installed
    bcrypt version treats longer passwords (truncation vs. rejection).
    """
    raw_password = password.encode('utf-8')
    salt = bcrypt.gensalt()
    hashed = bcrypt.hashpw(raw_password, salt)
    return hashed.decode('utf-8')
|
|
|
|
|
|
def verify_password(password: str, password_hash: str) -> bool:
    """Check ``password`` against a stored bcrypt ``password_hash``.

    Both strings are UTF-8 encoded and passed to ``bcrypt.checkpw``, whose
    result is returned unchanged.
    """
    candidate = password.encode('utf-8')
    stored = password_hash.encode('utf-8')
    return bcrypt.checkpw(candidate, stored)
|
|
|
|
|
|
def hash_token(token: str) -> str:
    """Return the SHA-256 hex digest of ``token``, for storing tokens at rest.

    The token is encoded with ``str.encode``'s default codec (UTF-8) before
    hashing; the result is the digest as lowercase hexadecimal text.
    """
    digest = hashlib.sha256()
    digest.update(token.encode())
    return digest.hexdigest()
|
|
|
|
|
|
def verify_token_hash(token: str, token_hash: str) -> bool:
    """Return True if ``token`` hashes (via ``hash_token``) to ``token_hash``.

    The comparison uses ``hmac.compare_digest`` instead of ``==``. ``==`` stops
    at the first mismatching character, so its running time leaks how much of
    the stored digest matched. Both sides are encoded to bytes first because
    ``compare_digest`` rejects ``str`` arguments containing non-ASCII
    characters. Any mismatch, including a malformed ``token_hash``, still
    simply yields False.
    """
    expected = hash_token(token).encode()
    return hmac.compare_digest(expected, token_hash.encode())
|
|
|
|
|
|
def generate_api_key() -> str:
    """Return a new API key.

    The key is the ``pit_`` prefix followed by a URL-safe random token built
    from 32 random bytes.
    """
    prefix = "pit_"
    return prefix + secrets.token_urlsafe(32)
|
|
|
|
|
|
def secure_compare(val1: str, val2: str) -> bool:
    """Compare two strings in a way that resists timing attacks.

    Both values are encoded to bytes (UTF-8 by default), so non-ASCII text is
    accepted. They are then compared with ``hmac.compare_digest`` instead of
    ``==``.
    """
    left, right = val1.encode(), val2.encode()
    return hmac.compare_digest(left, right)
|
|
|
|
|
|
def mask_sensitive_data(data: str, visible_chars: int = 4) -> str:
    """Mask the middle of ``data``, keeping ``visible_chars`` characters at each end.

    Example: ``mask_sensitive_data("1234567890") == "1234**7890"``.

    If ``data`` is no longer than ``2 * visible_chars`` (nothing would be
    hidden), every character is replaced by ``*``. A ``visible_chars`` of 0 or
    less masks every character.
    """
    # Negative counts make no sense here and previously leaked the plaintext
    visible = max(visible_chars, 0)
    length = len(data)
    if length <= visible * 2:
        return '*' * length
    head = data[:visible]
    # Slice the tail by absolute index: data[-0:] is the WHOLE string, which
    # leaked the full plaintext after the stars when visible_chars == 0.
    tail = data[length - visible:]
    return head + '*' * (length - visible * 2) + tail
|
|
|
|
|
|
class RateLimiter:
    """Simple in-memory sliding-window rate limiter.

    For every key it keeps the ``time.time()`` timestamps of the requests it
    accepted. A request is allowed while fewer than ``limit`` of those
    timestamps fall inside the last ``window`` seconds. State lives only in
    this instance and no locking is performed, so concurrent callers sharing
    one instance are not synchronized.
    """

    def __init__(self):
        # key -> list of accepted-request timestamps (seconds, from time.time())
        self._storage = {}
        # Largest window passed to is_allowed so far. A timestamp older than
        # this cannot count towards any check made with a window <= it, so a
        # key holding only such timestamps can be forgotten.
        self._max_window = 0
        # time.time() value of the last full sweep over all keys
        self._last_sweep = 0.0

    def is_allowed(self, key: str, limit: int = 100, window: int = 60) -> bool:
        """Record and allow a request for ``key`` unless it is over its limit.

        Args:
            key: Identifier whose requests are counted.
            limit: Maximum accepted requests per window.
            window: Window length in seconds.

        Returns:
            True if the request is accepted (and recorded), False otherwise.
            Rejected requests are not recorded.
        """
        from time import time

        now = time()
        window_start = now - window

        # Periodically forget idle keys so memory does not grow without bound
        # as ever more distinct keys arrive. Note: if a later call uses a
        # larger window than any seen before, timestamps dropped by an earlier
        # sweep are no longer counted.
        self._max_window = max(self._max_window, window)
        self._sweep_stale(now)

        # Keep only this key's requests that are still inside the window
        requests = [t for t in self._storage.get(key, []) if t > window_start]

        if len(requests) >= limit:
            if requests:
                self._storage[key] = requests
            else:
                # Only reachable when limit <= 0: don't keep an empty entry
                self._storage.pop(key, None)
            return False

        requests.append(now)
        self._storage[key] = requests
        return True

    def _sweep_stale(self, now: float) -> None:
        """Drop keys with no timestamp newer than ``now - self._max_window``.

        Runs at most once per ``_max_window`` seconds, so the scan over all
        keys is amortised across many calls.
        """
        if now - self._last_sweep < self._max_window:
            return
        self._last_sweep = now
        cutoff = now - self._max_window
        stale = [
            key
            for key, stamps in self._storage.items()
            if all(t <= cutoff for t in stamps)
        ]
        for key in stale:
            del self._storage[key]
|
|
|
|
|
|
class IPWhitelist:
    """IP whitelist matching exact addresses, CIDR ranges, or everything via '*'."""

    def __init__(self, allowed_ips: Optional[list] = None):
        """Build the whitelist.

        Args:
            allowed_ips: Entries to allow. An entry containing '/' is treated
                as a CIDR range; any other entry must equal the client IP
                string exactly. A '*' entry allows every address. ``None``
                (the default) allows nothing.
        """
        # Copy into a set: dedupes and decouples from the caller's list
        self.allowed_ips = set(allowed_ips or [])
        # Evaluated once here; later edits to allowed_ips do not update it
        self.allow_all = '*' in self.allowed_ips

    def is_allowed(self, ip: str) -> bool:
        """Return True if ``ip`` matches any whitelist entry."""
        if self.allow_all:
            return True

        for allowed in self.allowed_ips:
            if '/' in allowed:
                # CIDR entry (parsed on every call)
                if self._ip_in_cidr(ip, allowed):
                    return True
            elif ip == allowed:
                # Plain entry: exact string match, no address normalisation
                return True

        return False

    def _ip_in_cidr(self, ip: str, cidr: str) -> bool:
        """Return True if ``ip`` lies inside the network ``cidr``.

        Host bits set in ``cidr`` are ignored (``strict=False``). A malformed
        ``ip`` or ``cidr`` yields False.
        """
        import ipaddress

        try:
            network = ipaddress.ip_network(cidr, strict=False)
            address = ipaddress.ip_address(ip)
        except (ValueError, TypeError):
            # Unparseable entry or client address counts as "no match". Only
            # parse errors are caught: the previous bare except also swallowed
            # KeyboardInterrupt/SystemExit and genuine bugs.
            return False
        return address in network
|