Files
pit-router/app/utils/security.py

122 lines
3.1 KiB
Python

"""
安全工具
"""
import hashlib
import hmac
import ipaddress
import secrets
import time
from collections import deque
from typing import Optional

import bcrypt
def generate_token(length: int = 32) -> str:
    """Return a random, URL-safe token drawn from the OS CSPRNG.

    Args:
        length: number of random *bytes* to draw. This is not the length of
            the returned string; the URL-safe Base64 encoding yields roughly
            ``1.3 * length`` characters (43 for the default of 32).
    """
    token = secrets.token_urlsafe(nbytes=length)
    return token
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt.

    Returns the complete bcrypt hash (salt and cost factor are embedded in
    it) decoded to a UTF-8 string, ready to be stored.

    NOTE(review): bcrypt only considers the first 72 bytes of input; depending
    on the installed bcrypt version longer passwords are truncated or rejected
    with ValueError -- confirm callers bound the password length.
    """
    raw_password = password.encode('utf-8')
    salt = bcrypt.gensalt()
    hashed = bcrypt.hashpw(raw_password, salt)
    return hashed.decode('utf-8')
def verify_password(password: str, password_hash: str) -> bool:
    """Check *password* against a stored bcrypt *password_hash*.

    Returns:
        True if the password matches the hash, otherwise False.

    An empty/None stored hash, or one bcrypt cannot parse, is treated as a
    non-match instead of raising, so a corrupt credential record reads as a
    failed login rather than an unhandled error.
    """
    if not password_hash:
        return False
    try:
        return bcrypt.checkpw(
            password.encode('utf-8'),
            password_hash.encode('utf-8'),
        )
    except ValueError:
        # bcrypt raises ValueError (e.g. "Invalid salt") for malformed hashes.
        return False
def hash_token(token: str) -> str:
    """Return the hex SHA-256 digest of *token* (UTF-8 encoded), for storage.

    Storing only the digest means a leaked table does not expose usable tokens.
    """
    digest = hashlib.sha256()
    digest.update(token.encode('utf-8'))
    return digest.hexdigest()
def verify_token_hash(token: str, token_hash: str) -> bool:
    """Check whether *token* hashes (via ``hash_token``) to the stored *token_hash*.

    The comparison uses ``hmac.compare_digest`` rather than ``==`` so its
    running time does not reveal how many leading characters matched. Both
    sides are encoded to bytes first because ``compare_digest`` rejects
    non-ASCII ``str`` arguments.

    Returns:
        True on a match; False otherwise, including when *token_hash* is not a str.
    """
    if not isinstance(token_hash, str):
        # Mirrors the original `==` semantics: None/bytes never matched.
        return False
    return hmac.compare_digest(hash_token(token).encode(), token_hash.encode())
def generate_api_key() -> str:
    """Return a new API key: the ``pit_`` prefix followed by 32 random bytes
    encoded as URL-safe Base64 (43 characters)."""
    prefix = "pit_"
    random_part = secrets.token_urlsafe(32)
    return prefix + random_part
def secure_compare(val1: str, val2: str) -> bool:
    """Compare two strings in constant time to resist timing attacks.

    Both values are UTF-8 encoded before comparison because
    ``hmac.compare_digest`` only accepts ASCII ``str`` or bytes-like input.
    """
    left, right = val1.encode('utf-8'), val2.encode('utf-8')
    return hmac.compare_digest(left, right)
def mask_sensitive_data(data: str, visible_chars: int = 4) -> str:
    """Mask the middle of *data*, keeping *visible_chars* characters at each end.

    Args:
        data: the sensitive string to mask.
        visible_chars: characters to keep visible at the start and at the end.
            Zero or negative values mask the whole string.

    Returns:
        A string of the same length as *data*. Strings too short to keep both
        ends visible (``len(data) <= 2 * visible_chars``) are fully masked.
    """
    visible_chars = max(visible_chars, 0)
    length = len(data)
    if length <= visible_chars * 2:
        return '*' * length
    head = data[:visible_chars]
    # Explicit start index instead of data[-visible_chars:]: with
    # visible_chars == 0 the negative slice data[-0:] is the whole string,
    # which would leak the unmasked secret.
    tail = data[length - visible_chars:]
    return head + '*' * (length - visible_chars * 2) + tail
class RateLimiter:
    """Simple in-memory sliding-window rate limiter.

    For each key it keeps the timestamps of requests accepted within the last
    ``window`` seconds. State lives only in this object (one process).

    NOTE(review): keys are never evicted, so memory grows with the number of
    distinct keys seen, and no locking is performed -- confirm both are
    acceptable for how this limiter is shared.
    """

    def __init__(self):
        # key -> deque of accepted-request timestamps (time.monotonic()),
        # oldest on the left.
        self._storage = {}

    def is_allowed(self, key: str, limit: int = 100, window: int = 60) -> bool:
        """Decide whether a request for *key* may proceed, recording it if so.

        Args:
            key: identifier being limited (e.g. client IP or user id).
            limit: maximum accepted requests per window.
            window: window length in seconds.

        Returns:
            True (and records the request) if fewer than *limit* requests were
            accepted in the last *window* seconds; False otherwise. Rejected
            requests are not recorded.
        """
        # Monotonic clock: wall-clock steps (NTP, manual changes) must not
        # shrink or stretch the window.
        now = time.monotonic()
        window_start = now - window
        timestamps = self._storage.setdefault(key, deque())
        # Timestamps are appended in non-decreasing order, so every expired
        # entry sits at the left end and can be dropped in O(1) each.
        while timestamps and timestamps[0] <= window_start:
            timestamps.popleft()
        if len(timestamps) >= limit:
            return False
        timestamps.append(now)
        return True
class IPWhitelist:
    """IP allowlist matching exact addresses, CIDR ranges and a ``'*'`` wildcard."""

    def __init__(self, allowed_ips: Optional[list] = None):
        """
        Args:
            allowed_ips: entries such as ``"10.0.0.1"``, ``"10.0.0.0/8"``,
                ``"::1"`` or ``"*"`` (allow every client). ``None`` means an
                empty allowlist, i.e. nothing is allowed.
        """
        self.allowed_ips = set(allowed_ips or [])
        # Evaluated once here; adding '*' to allowed_ips later has no effect.
        self.allow_all = '*' in self.allowed_ips

    def is_allowed(self, ip: str) -> bool:
        """Return True if *ip* matches any allowlist entry.

        Entries containing ``'/'`` are CIDR ranges. Other entries match by
        exact string equality or, when both sides parse as IP addresses, by
        address equality -- so ``"::1"`` matches ``"0:0:0:0:0:0:0:1"`` and an
        IPv4-mapped client address ``"::ffff:10.0.0.1"`` matches ``"10.0.0.1"``.
        """
        if self.allow_all:
            return True
        client_forms = self._address_forms(ip)
        for allowed in self.allowed_ips:
            if '/' in allowed:
                if self._ip_in_cidr(ip, allowed):
                    return True
            elif ip == allowed or any(
                form in client_forms for form in self._address_forms(allowed)
            ):
                return True
        return False

    def _ip_in_cidr(self, ip: str, cidr: str) -> bool:
        """Return True if *ip* lies in the *cidr* range; False if either is malformed."""
        try:
            network = ipaddress.ip_network(cidr, strict=False)
        except ValueError:
            return False
        # `in` is simply False when the address and network IP versions differ.
        return any(address in network for address in self._address_forms(ip))

    @staticmethod
    def _address_forms(value) -> tuple:
        """Parse *value* into IP address objects; return () if it is not an IP.

        An IPv4-mapped IPv6 address (``::ffff:a.b.c.d``, e.g. an IPv4 client
        seen through a dual-stack listener) yields both its IPv6 form and the
        embedded IPv4 address, so it can match IPv4 entries and ranges.
        """
        try:
            address = ipaddress.ip_address(value)
        except ValueError:
            return ()
        mapped = getattr(address, 'ipv4_mapped', None)
        return (address,) if mapped is None else (address, mapped)